Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
6c18971f00 feat: 初步接入米家小爱音箱 2025-01-25 02:30:00 +08:00
71 changed files with 1351 additions and 1611 deletions

View File

@@ -1,12 +1,14 @@
<p align="center">
![logo](https://github.com/user-attachments/assets/07649e07-3b8e-4feb-9aa9-bf13af4f3476)
<p align="center">
<img src="https://github.com/user-attachments/assets/de10f24d-cd64-433a-90b8-16c0a60de24a" width=500>
</p>
<div align="center">
<h1>AstrBot</h1>
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
@@ -14,8 +16,9 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
@@ -35,7 +38,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM因此无法在聊天页使用大模型。
## ✨ 使用方式
@@ -64,31 +67,19 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ⚡ 消息平台支持情况
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
| QQ 官方API | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| 微信(企业微信) | 🚧 | 计划内 | - |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| 飞书 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
# 🦌 接下来的路线图
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
@@ -140,21 +131,8 @@ _✨ 内置 Web Chat在线与机器人交互 ✨_
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## Sponsors
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
@@ -166,6 +144,5 @@ _✨ 内置 Web Chat在线与机器人交互 ✨_
4. TTS
-->
_私は、高性能ですから!_
_アトリは、高性能ですから!_

View File

@@ -6,7 +6,6 @@ from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent
@@ -32,6 +31,5 @@ __all__ = [
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent',
'on_llm_response'
'after_message_sent'
]

View File

@@ -2,5 +2,4 @@ from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
from astrbot.core.platform.register import register_platform_adapter

View File

@@ -1,2 +1,2 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData

View File

@@ -2,7 +2,7 @@ import os
import json
import logging
import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from .default import DEFAULT_CONFIG
from typing import Dict
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
@@ -13,72 +13,29 @@ class RateLimitStrategy(enum.Enum):
DISCARD = "discard"
class AstrBotConfig(dict):
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema将会通过 schema 解析出 default_config此时传入的 default_config 会被忽略。
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None
):
def __init__(self):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
'''不存在时载入默认配置'''
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
with open(config_path, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(default_config, conf)
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
@@ -104,7 +61,7 @@ class AstrBotConfig(dict):
'''
if replace_config:
self.update(replace_config)
with open(self.config_path, "w", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
@@ -124,4 +81,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(self.config_path)
return os.path.exists(ASTRBOT_CONFIG_PATH)

View File

@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.17"
VERSION = "3.4.11"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -24,20 +24,13 @@ DEFAULT_CONFIG = {
"wl_ignore_admin_on_friend": True,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": [],
"segmented_reply": {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"regex": ".*?[。?!~…]+|.+$"
}
"path_mapping": []
},
"provider": [],
"provider_settings": {
"enable": True,
"wake_prefix": "",
"web_search": False,
"web_search_link": False,
"identifier": False,
"datetime_system_prompt": True,
"default_personality": "default",
@@ -47,23 +40,6 @@ DEFAULT_CONFIG = {
"enable": False,
"provider_id": "",
},
"provider_tts_settings": {
"enable": False,
"provider_id": "",
},
"provider_ltm_settings": {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_prompt": "Please describe the image using Chinese.",
"active_reply": {
"enable": False,
"method": "possibility_reply",
"possibility_reply": 0.1,
"prompt": "",
},
"put_history_to_prompt": True,
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -86,7 +62,7 @@ DEFAULT_CONFIG = {
"persona": [
{
"name": "default",
"prompt": "",
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
@@ -119,15 +95,27 @@ CONFIG_METADATA_2 = {
"ws_reverse_host": "",
"ws_reverse_port": 6199,
},
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
"gewechat(微信)": {
"id": "gwchat",
"type": "gewechat",
"enable": False,
"base_url": "http://localhost:2531",
"nickname": "soulter",
"host": "这里填写你的局域网IP或者公网服务器IP",
"host": "localhost",
"port": 11451,
},
"mispeaker(小爱音箱)": {
"id": "mispeaker",
"type": "mispeaker",
"enable": False,
"username": "",
"password": "",
"did": "",
"activate_word": "测试",
"deactivate_word": "停止",
"interval": 1,
},
},
"items": {
"id": {
@@ -201,31 +189,6 @@ CONFIG_METADATA_2 = {
},
},
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
"items": {
"enable": {
"description": "启用分段回复",
"type": "bool",
},
"only_llm_result": {
"description": "仅对 LLM 结果分段",
"type": "bool",
},
"interval": {
"description": "随机间隔时间(秒)",
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"regex": {
"description": "正则表达式",
"type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
},
},
},
"reply_prefix": {
"description": "回复前缀",
"type": "string",
@@ -243,9 +206,8 @@ CONFIG_METADATA_2 = {
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID当一条消息没通过白名单时会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
"items": {"type": "int"},
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID当一条消息没通过白名单时会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -273,7 +235,6 @@ CONFIG_METADATA_2 = {
"path_mapping": {
"description": "路径映射",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
}
@@ -332,7 +293,7 @@ CONFIG_METADATA_2 = {
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.openai.com/v1",
"api_base": "",
"model_config": {
"model": "gpt-4o-mini",
},
@@ -421,24 +382,8 @@ CONFIG_METADATA_2 = {
"type": "openai_whisper_selfhost",
"model": "tiny",
},
"openai_tts(API)": {
"id": "openai_tts",
"type": "openai_tts_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "tts-1",
"openai-tts-voice": "alloy",
"timeout": "20",
},
},
"items": {
"openai-tts-voice": {
"description": "voice",
"type": "string",
"obvious_hint": True,
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
},
"whisper_hint": {
"description": "本地部署 Whisper 模型须知",
"type": "string",
@@ -469,7 +414,7 @@ CONFIG_METADATA_2 = {
"api_base": {
"description": "API Base URL",
"type": "string",
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现 404 报错,可以尝试在地址末尾加上 `/v1`。",
"obvious_hint": True,
},
"base_model_path": {
@@ -555,25 +500,16 @@ CONFIG_METADATA_2 = {
"web_search": {
"description": "启用网页搜索",
"type": "bool",
"obvious_hint": True,
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
},
"web_search_link": {
"description": "网页搜索引用链接",
"type": "bool",
"obvious_hint": True,
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
},
"identifier": {
"description": "启动识别群员",
"type": "bool",
"obvious_hint": True,
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
},
"datetime_system_prompt": {
"description": "启用日期时间系统提示",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
},
"default_personality": {
@@ -615,15 +551,15 @@ CONFIG_METADATA_2 = {
"begin_dialogs": {
"description": "预设对话",
"type": "list",
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色之后按回车",
"items": {},
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
"obvious_hint": True,
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。对话需要成对(用户和助手),输入完一个角色之后按回车",
"items": {},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
"obvious_hint": True,
},
},
@@ -645,87 +581,6 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_tts_settings": {
"description": "文本转语音(TTS)",
"type": "object",
"items": {
"enable": {
"description": "启用文本转语音(TTS)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个TTS提供商",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
"provider_ltm_settings": {
"description": "聊天记忆增强(Beta)",
"type": "object",
"items": {
"group_icl_enable": {
"description": "群聊内记录各群员对话",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
},
"group_message_max_cnt": {
"description": "群聊消息最大数量",
"type": "int",
"obvious_hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需要模型支持)",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string"
},
"active_reply": {
"description": "主动回复",
"type": "object",
"items": {
"enable": {
"description": "启用主动回复",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会根据触发概率主动回复群聊内的对话。",
},
"method": {
"description": "回复方法",
"type": "string",
"options": ["possibility_reply"],
"hint": "回复方法。possibility_reply 为根据概率回复",
},
"possibility_reply": {
"description": "回复概率",
"type": "float",
"obvious_hint": True,
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
},
"prompt": {
"description": "提示词",
"type": "string",
"obvious_hint": True,
"hint": "提示词。当提示词为空时如果触发回复prompt是触发的消息的内容否则是提示词。此项可以和定时回复暂未实现配合使用。",
},
},
},
"put_history_to_prompt": {
"description": "将群聊历史记录作为 prompt",
"type": "bool",
"obvious_hint": True,
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复建议启用模型能够更好地完成回复任务。",
}
},
},
},
},
"misc_config_group": {
@@ -735,8 +590,7 @@ CONFIG_METADATA_2 = {
"description": "机器人唤醒前缀",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`则内置指令help等将需要通过您的唤醒前缀来触发。",
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
},
"t2i": {
"description": "文本转图像",
@@ -746,7 +600,7 @@ CONFIG_METADATA_2 = {
"admins_id": {
"description": "管理员 ID",
"type": "list",
"items": {"type": "string"},
"items": {"type": "int"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {

View File

@@ -7,6 +7,7 @@ from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager

View File

@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str]
path: T.Optional[str] # 用这个
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
@@ -306,7 +306,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: T.Union[str, int]
id: int
text: T.Optional[str] = ""
qq: T.Optional[int] = 0
time: T.Optional[int] = 0

View File

@@ -13,10 +13,12 @@ class MessageChain():
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''添加一条文本消息到消息链 `chain` 中。
@@ -75,6 +77,16 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''设置是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
@@ -101,6 +113,7 @@ class MessageEventResult(MessageChain):
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
@@ -126,7 +139,7 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
@@ -135,15 +148,5 @@ class MessageEventResult(MessageChain):
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult

View File

@@ -21,10 +21,6 @@ class DifyRequestSubStage(Stage):
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
return
if provider.meta().type != "dify":
return

View File

@@ -51,7 +51,7 @@ class LLMRequestSubStage(Stage):
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt and not req.image_urls:
if not req.prompt:
return
# 执行请求 LLM 前事件。
@@ -68,15 +68,6 @@ class LLMRequestSubStage(Stage):
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':

View File

@@ -39,11 +39,8 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()

View File

@@ -37,11 +37,7 @@ class ProcessStage(Stage):
# Handler 的 LLM 请求
logger.debug(f"llm request -> {resp.prompt}")
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
_t = True
yield
if not _t:
yield
else:
yield
@@ -53,11 +49,6 @@ class ProcessStage(Stage):
if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
match provider.meta().type:
case "dify":
async for _ in self.dify_request_sub_stage.process(event):

View File

@@ -1,10 +1,7 @@
import random
import asyncio
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -12,19 +9,6 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间:{self.interval}")
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -32,16 +16,7 @@ class RespondStage(Stage):
return
if len(result.chain) > 0:
await event._pre_send()
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
# 分段回复
for comp in result.chain:
await event.send(MessageChain([comp]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
else:
await event.send(result)
await event._post_send()
await event.send(result)
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)

View File

@@ -1,13 +1,11 @@
import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record
from astrbot.core.message.components import Plain, Image, At, Reply
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -18,12 +16,7 @@ class ResultDecorateStage:
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
self.t2i = ctx.astrbot_config['t2i']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -38,53 +31,10 @@ class ResultDecorateStage:
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
for comp in result.chain:
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
if self.use_tts and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + comp.text)
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
result.chain.insert(0, Plain(self.reply_prefix))
# 文本转图片
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
@@ -104,7 +54,7 @@ class ResultDecorateStage:
result.chain = [Image.fromURL(url)]
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
result.chain.insert(0, At(qq=event.get_sender_id()))
if self.reply_with_quote:
result.chain.insert(0, Reply(id=event.message_obj.message_id))

View File

@@ -18,11 +18,6 @@ class WhitelistCheckStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
# 白名单检查未启用
return
if len(self.whitelist) == 0:
# 白名单为空,不检查
return
if event.get_platform_name() == 'webchat':

View File

@@ -179,15 +179,6 @@ class AstrMessageEvent(abc.ABC):
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
'''设置消息事件的结果。
@@ -296,7 +287,6 @@ class AstrMessageEvent(abc.ABC):
def request_llm(
self,
prompt: str,
func_tool_manager = None,
session_id: str = None,
image_urls: List[str] = None,
contexts: List = None,
@@ -312,13 +302,11 @@ class AstrMessageEvent(abc.ABC):
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
'''
return ProviderRequest(
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
func_tool = func_tool_manager,
contexts = contexts,
system_prompt = system_prompt
)

View File

@@ -24,12 +24,11 @@ class PlatformManager():
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
try:
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
except BaseException:
logger.warning("当前 astrbot 已不维护 vchat 的接入,如有需要请 pip 安装 vchat 然后重启")
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "mispeaker":
from .sources.mispeaker.mispeaker_adapter import MiSpeakerPlatformAdapter # noqa: F401
async def initialize(self):

View File

@@ -3,7 +3,7 @@ import random
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -20,18 +20,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, (Image, Record)):
if isinstance(segment, Image):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_base64 = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
else:
bs64_data = file_to_base64(segment.file)
image_base64 = file_to_base64(image_file_path)
d['data'] = {
'file': bs64_data,
'file': image_base64,
}
ret.append(d)
return ret
@@ -40,5 +38,11 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
await self.bot.send(self.message_obj.raw_message, ret)
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
else:
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)

View File

@@ -2,14 +2,10 @@ import threading
import asyncio
import aiohttp
import quart
import base64
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api.message_components import Plain, Image, At
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.core.utils.io import download_image_by_url
class SimpleGewechatClient():
'''针对 Gewechat 的简单实现。
@@ -21,15 +17,9 @@ class SimpleGewechatClient():
self.base_url = base_url
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
@@ -37,19 +27,15 @@ class SimpleGewechatClient():
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.callback_url = None
self.server = quart.Quart(__name__)
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
@@ -66,87 +52,57 @@ class SimpleGewechatClient():
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
abm = AstrBotMessage()
d = data['Data']
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
msg_type = d['MsgType']
abm.message_id = str(d.get('MsgId'))
abm.session_id = from_user_name
abm.self_id = data['Wxid'] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
.replace('在群聊中@了你', '') \
.replace('在群聊中发了一段语音', '') # 真实昵称
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
# 不同消息类型
match d['MsgType']:
match msg_type:
case 1:
# 文本消息
abm.message.append(Plain(content))
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
user_real_name = d['PushContent'].split(' : ')[0] # 真实昵称
user_real_name = user_real_name.replace('在群聊中@了你', '') # trick
abm.self_id = data['Wxid'] # 机器人的 wxid
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.session_id = from_user_name
abm.sender = MessageMember(user_id, user_real_name)
abm.message = [Plain(content)]
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
abm.message_id = str(d['MsgId'])
abm.raw_message = d
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid,
content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
with open(file_path, "wb") as f:
f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
logger.info(f"abm: {abm}")
return abm
case _:
logger.error(f"未实现的消息类型: {d['MsgType']}")
return
logger.info(f"abm: {abm}")
return abm
logger.error(f"未实现的消息类型: {msg_type}")
async def callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
@@ -154,43 +110,40 @@ class SimpleGewechatClient():
if data.get('testMsg', None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}")
abm = await self._convert(data)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={
"token": self.token,
"callbackUrl": self.callback_url
"callbackUrl": callback_url
}
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob['ret'] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
logger.info(f"将在 {callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
async def start_polling(self):
# 设置回调
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host=self.host,
port=self.port,
@@ -233,8 +186,6 @@ class SimpleGewechatClient():
async def login(self):
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
if self.appid:
online = await self.check_online(self.appid)
@@ -294,7 +245,7 @@ class SimpleGewechatClient():
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{self.nickname}", appid)
sp.put(f"gewechat-appid-{nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
@@ -312,39 +263,4 @@ class SimpleGewechatClient():
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
logger.info(f"发送消息结果: {json_blob}")

View File

@@ -1,51 +0,0 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader():
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {
"Content-Type": "application/json",
"X-GEWE-TOKEN": token
}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}",
headers=self.headers,
json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {
"appId": appid,
"xml": xml,
"msgId": msg_id
}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
'''返回一个可下载的 URL'''
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {
"appId": appid,
"xml": xml,
"type": choice
}
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
json_blob = json.loads(data)
if 'fileUrl' in json_blob['data']:
return self.download_base_url + json_blob['data']['fileUrl']
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")

View File

@@ -1,24 +1,12 @@
import wave
import uuid
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
import random
import asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
@@ -46,57 +34,5 @@ class GewechatPlatformEvent(AstrMessageEvent):
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath('data/temp')
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
duration = await wav_to_tencent_silk(record_path, silk_path)
print(f"duration: {duration}, {silk_path}")
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# temp_directory = os.path.abspath('data/temp')
# record_path = os.path.abspath(record_path)
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
# with open(record_path, "rb") as f:
# record_path = f"data/temp/{uuid.uuid4()}.wav"
# with open(record_path, "wb") as f2:
# f2.write(f.read())
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
await self.client.post_voice(to_wxid, record_url, duration*1000)
await super().send(message)

View File

@@ -0,0 +1,137 @@
import os
import json
import asyncio
import aiohttp
import time
import traceback
from .miservice import MiAccount, MiNAService, MiIOService, miio_command, miio_command_help
from astrbot.core import logger
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At
class SimpleMiSpeakerClient():
'''
@author: Soulter
@references: https://github.com/yihong0618/xiaogpt/blob/main/xiaogpt/xiaogpt.py
'''
def __init__(self, config: dict):
self.username = config['username']
self.password = config['password']
self.did = config['did']
self.store = os.path.join("data", '.mi.token')
self.interval = float(config.get('interval', 1))
self.conv_query_cookies = {
'userId': '',
'deviceId': '',
'serviceToken': ''
}
self.MI_CONVERSATION_URL = "https://userprofile.mina.mi.com/device_profile/v2/conversation?source=dialogu&hardware={hardware}&timestamp={timestamp}&limit=1"
self.session = aiohttp.ClientSession()
self.activate_word = config.get('activate_word', '测试')
self.deactivate_word = config.get('deactivate_word', '停止')
self.entered = False
async def initialize(self):
account = MiAccount(self.session, self.username, self.password, self.store)
self.miio_service = MiIOService(account) # 小米设备服务
self.mina_service = MiNAService(account) # 小爱音箱服务
device = await self.get_mina_device()
self.deviceID = device['deviceID']
self.hardware = device['hardware']
with open(self.store, 'r') as f:
data = json.load(f)
self.userId = data['userId']
self.serviceToken = data['micoapi'][1]
self.conv_query_cookies['userId'] = self.userId
self.conv_query_cookies['deviceId'] = self.deviceID
self.conv_query_cookies['serviceToken'] = self.serviceToken
logger.info(f"MiSpeakerClient initialized. Conv cookies: {self.conv_query_cookies}. Hardware: {self.hardware}")
async def get_mina_device(self) -> dict:
devices = await self.mina_service.device_list()
for device in devices:
if device['miotDID'] == self.did:
logger.info(f"找到设备 {device['alias']}({device['name']}) 了!")
return device
async def get_conv(self) -> str:
# 时区请确保为北京时间
async with aiohttp.ClientSession() as session:
session.cookie_jar.update_cookies(self.conv_query_cookies)
query_ts = int(time.time())*1000
logger.debug(f"Querying conversation at {query_ts}")
async with session.get(self.MI_CONVERSATION_URL.format(hardware=self.hardware, timestamp=str(query_ts))) as resp:
json_blob = await resp.json()
if json_blob['code'] == 0:
data = json.loads(json_blob['data'])
records = data.get('records', None)
for record in records:
if record['time'] >= query_ts - self.interval*1000:
return record['query']
else:
logger.error(f"Failed to get conversation: {json_blob}")
return None
async def start_pooling(self):
while True:
await asyncio.sleep(self.interval)
try:
query = await self.get_conv()
if not query:
continue
# is wake
if query == self.activate_word:
self.entered = True
await self.stop_playing()
await self.send("我来啦!")
continue
elif query == self.deactivate_word:
self.entered = False
await self.stop_playing()
await self.send("再见,欢迎给个 Star。")
continue
if not self.entered:
continue
await self.send("")
abm = await self._convert(query)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
except BaseException as e:
traceback.print_exc()
logger.error(e)
async def _convert(self, query: str):
abm = AstrBotMessage()
abm.message = [Plain(query)]
abm.message_id = str(int(time.time()))
abm.message_str = query
abm.raw_message = query
abm.session_id = f"{self.hardware}_{self.did}_{self.username}"
abm.sender = MessageMember(self.username, "主人")
abm.self_id = f"{self.hardware}_{self.did}"
abm.type = MessageType.FRIEND_MESSAGE
return abm
async def send(self, message: str):
text = f'5 {message}'
await miio_command(self.miio_service, self.did, text, 'astrbot')
async def stop_playing(self):
text = f'3-2'
await miio_command(self.miio_service, self.did, text, 'astrbot')

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021-2022 Yonsm
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,5 @@
from .miaccount import MiAccount, MiTokenStore
from .minaservice import MiNAService
from .miioservice import MiIOService
from .miiocommand import miio_command, miio_command_help

View File

@@ -0,0 +1,135 @@
import base64
import hashlib
import json
import logging
import os
import random
import string
from urllib import parse
from aiohttp import ClientSession
from aiofiles import open as async_open
_LOGGER = logging.getLogger(__package__)
def get_random(length):
return ''.join(random.sample(string.ascii_letters + string.digits, length))
class MiTokenStore:
def __init__(self, token_path):
self.token_path = token_path
async def load_token(self):
if os.path.isfile(self.token_path):
try:
async with async_open(self.token_path) as f:
return json.loads(await f.read())
except Exception as e:
_LOGGER.exception("Exception on load token from %s: %s", self.token_path, e)
return None
async def save_token(self, token=None):
if token:
try:
async with async_open(self.token_path, 'w') as f:
await f.write(json.dumps(token, indent=2))
except Exception as e:
_LOGGER.exception("Exception on save token to %s: %s", self.token_path, e)
elif os.path.isfile(self.token_path):
os.remove(self.token_path)
class MiAccount:
def __init__(self, session: ClientSession, username, password, token_store='.mi.token'):
self.session = session
self.username = username
self.password = password
self.token_store = MiTokenStore(token_store) if isinstance(token_store, str) else token_store
self.token = None
async def login(self, sid):
if not self.token:
self.token = {'deviceId': get_random(16).upper()}
try:
resp = await self._serviceLogin(f'serviceLogin?sid={sid}&_json=true')
if resp['code'] != 0:
data = {
'_json': 'true',
'qs': resp['qs'],
'sid': resp['sid'],
'_sign': resp['_sign'],
'callback': resp['callback'],
'user': self.username,
'hash': hashlib.md5(self.password.encode()).hexdigest().upper()
}
resp = await self._serviceLogin('serviceLoginAuth2', data)
if resp['code'] != 0:
raise Exception(resp)
self.token['userId'] = resp['userId']
self.token['passToken'] = resp['passToken']
serviceToken = await self._securityTokenService(resp['location'], resp['nonce'], resp['ssecurity'])
self.token[sid] = (resp['ssecurity'], serviceToken)
if self.token_store:
await self.token_store.save_token(self.token)
return True
except Exception as e:
self.token = None
if self.token_store:
await self.token_store.save_token()
_LOGGER.exception("Exception on login %s: %s", self.username, e)
return False
async def _serviceLogin(self, uri, data=None):
headers = {'User-Agent': 'APP/com.xiaomi.mihome APPV/6.0.103 iosPassportSDK/3.9.0 iOS/14.4 miHSTS'}
cookies = {'sdkVersion': '3.9', 'deviceId': self.token['deviceId']}
if 'passToken' in self.token:
cookies['userId'] = self.token['userId']
cookies['passToken'] = self.token['passToken']
url = 'https://account.xiaomi.com/pass/' + uri
async with self.session.request('GET' if data is None else 'POST', url, data=data, cookies=cookies, headers=headers) as r:
raw = await r.read()
resp = json.loads(raw[11:])
_LOGGER.debug("%s: %s", uri, resp)
return resp
async def _securityTokenService(self, location, nonce, ssecurity):
nsec = 'nonce=' + str(nonce) + '&' + ssecurity
clientSign = base64.b64encode(hashlib.sha1(nsec.encode()).digest()).decode()
async with self.session.get(location + '&clientSign=' + parse.quote(clientSign)) as r:
serviceToken = r.cookies['serviceToken'].value
if not serviceToken:
raise Exception(await r.text())
return serviceToken
async def mi_request(self, sid, url, data, headers, relogin=True):
if self.token is None and self.token_store is not None:
self.token = await self.token_store.load_token()
if (self.token and sid in self.token) or await self.login(sid): # Ensure login
cookies = {'userId': self.token['userId'], 'serviceToken': self.token[sid][1]}
content = data(self.token, cookies) if callable(data) else data
method = 'GET' if data is None else 'POST'
_LOGGER.debug("%s %s", url, content)
async with self.session.request(method, url, data=content, cookies=cookies, headers=headers) as r:
status = r.status
if status == 200:
resp = await r.json(content_type=None)
code = resp['code']
if code == 0:
return resp
if 'auth' in resp.get('message', '').lower():
status = 401
else:
resp = await r.text()
if status == 401 and relogin:
_LOGGER.warn("Auth error on request %s %s, relogin...", url, resp)
self.token = None # Auth error, reset login
return await self.mi_request(sid, url, data, headers, False)
else:
resp = "Login failed"
raise Exception(f"Error {url}: {resp}")

View File

@@ -0,0 +1,104 @@
import json
from .miioservice import MiIOService
def twins_split(string, sep, default=None):
pos = string.find(sep)
return (string, default) if pos == -1 else (string[0:pos], string[pos+1:])
def string_to_value(string):
if string[0] in '"\'#':
return string[1:-1] if string[-1] in '"\'#' else string[1:]
elif string == 'null':
return None
elif string == 'false':
return False
elif string == 'true':
return True
elif string.isdigit():
return int(string)
try:
return float(string)
except:
return string
def miio_command_help(did=None, prefix='?'):
quote = '' if prefix == '?' else "'"
return f'\
Get Props: {prefix}<siid[-piid]>[,...]\n\
{prefix}1,1-2,1-3,1-4,2-1,2-2,3\n\
Set Props: {prefix}<siid[-piid]=[#]value>[,...]\n\
{prefix}2=60,2-1=#60,2-2=false,2-3="null",3=test\n\
Do Action: {prefix}<siid[-piid]> <arg1|[]> [...] \n\
{prefix}2 []\n\
{prefix}5 Hello\n\
{prefix}5-4 Hello 1\n\n\
Call MIoT: {prefix}<cmd=prop/get|/prop/set|action> <params>\n\
{prefix}action {quote}{{"did":"{did or "267090026"}","siid":5,"aiid":1,"in":["Hello"]}}{quote}\n\n\
Call MiIO: {prefix}/<uri> <data>\n\
{prefix}/home/device_list {quote}{{"getVirtualModel":false,"getHuamiDevices":1}}{quote}\n\n\
Devs List: {prefix}list [name=full|name_keyword] [getVirtualModel=false|true] [getHuamiDevices=0|1]\n\
{prefix}list Light true 0\n\n\
MIoT Spec: {prefix}spec [model_keyword|type_urn] [format=text|python|json]\n\
{prefix}spec\n\
{prefix}spec speaker\n\
{prefix}spec xiaomi.wifispeaker.lx04\n\
{prefix}spec urn:miot-spec-v2:device:speaker:0000A015:xiaomi-lx04:1\n\n\
MIoT Decode: {prefix}decode <ssecurity> <nonce> <data> [gzip]\n\
'
async def miio_command(service: MiIOService, did, text, prefix='?'):
cmd, arg = twins_split(text, ' ')
if cmd.startswith('/'):
return await service.miio_request(cmd, arg)
if cmd.startswith('prop') or cmd == 'action':
return await service.miot_request(cmd, json.loads(arg) if arg else None)
argv = arg.split(' ') if arg else []
argc = len(argv)
if cmd == 'list':
return await service.device_list(argc > 0 and argv[0], argc > 1 and string_to_value(argv[1]), argc > 2 and argv[2])
if cmd == 'spec':
return await service.miot_spec(argc > 0 and argv[0], argc > 1 and argv[1])
if cmd == 'decode':
return MiIOService.miot_decode(argv[0], argv[1], argv[2], argc > 3 and argv[3] == 'gzip')
if not did or not cmd or cmd == '?' or cmd == '' or cmd == 'help' or cmd == '-h' or cmd == '--help':
return miio_command_help(did, prefix)
if not did.isdigit():
devices = await service.device_list(did)
if not devices:
return "Device not found: " + did
did = devices[0]['did']
props = []
setp = True
miot = True
for item in cmd.split(','):
key, value = twins_split(item, '=')
siid, iid = twins_split(key, '-', '1')
if siid.isdigit() and iid.isdigit():
prop = [int(siid), int(iid)]
else:
prop = [key]
miot = False
if value is None:
setp = False
elif setp:
prop.append(string_to_value(value))
props.append(prop)
if miot and argc > 0:
args = [] if arg == '[]' else [string_to_value(a) for a in argv]
return await service.miot_action(did, props[0], args)
do_props = ((service.home_get_props, service.miot_get_props), (service.home_set_props, service.miot_set_props))[setp][miot]
return await do_props(did, props if miot or setp else [p[0] for p in props])

View File

@@ -0,0 +1,197 @@
import os
import time
import base64
import hashlib
import hmac
import json
# REGIONS = ['cn', 'de', 'i2', 'ru', 'sg', 'us']
class MiIOService:
def __init__(self, account=None, region=None):
self.account = account
self.server = 'https://' + ('' if region is None or region == 'cn' else region + '.') + 'api.io.mi.com/app'
async def miio_request(self, uri, data):
def prepare_data(token, cookies):
cookies['PassportDeviceId'] = token['deviceId']
return MiIOService.sign_data(uri, data, token['xiaomiio'][0])
headers = {'User-Agent': 'iOS-14.4-6.0.103-iPhone12,3--D7744744F7AF32F0544445285880DD63E47D9BE9-8816080-84A3F44E137B71AE-iPhone', 'x-xiaomi-protocal-flag-cli': 'PROTOCAL-HTTP2'}
resp = await self.account.mi_request('xiaomiio', self.server + uri, prepare_data, headers)
if 'result' not in resp:
raise Exception(f"Error {uri}: {resp}")
return resp['result']
async def home_request(self, did, method, params):
return await self.miio_request('/home/rpc/' + did, {'id': 1, 'method': method, "accessKey": "IOS00026747c5acafc2", 'params': params})
async def home_get_props(self, did, props):
return await self.home_request(did, 'get_prop', props)
async def home_set_props(self, did, props):
return [await self.home_set_prop(did, i[0], i[1]) for i in props]
async def home_get_prop(self, did, prop):
return (await self.home_get_props(did, [prop]))[0]
async def home_set_prop(self, did, prop, value):
result = (await self.home_request(did, 'set_' + prop, value if isinstance(value, list) else [value]))[0]
return 0 if result == 'ok' else result
async def miot_request(self, cmd, params):
return await self.miio_request('/miotspec/' + cmd, {'params': params})
async def miot_get_props(self, did, iids):
params = [{'did': did, 'siid': i[0], 'piid': i[1]} for i in iids]
result = await self.miot_request('prop/get', params)
return [it.get('value') if it.get('code') == 0 else None for it in result]
async def miot_set_props(self, did, props):
params = [{'did': did, 'siid': i[0], 'piid': i[1], 'value': i[2]} for i in props]
result = await self.miot_request('prop/set', params)
return [it.get('code', -1) for it in result]
async def miot_get_prop(self, did, iid):
return (await self.miot_get_props(did, [iid]))[0]
async def miot_set_prop(self, did, iid, value):
return (await self.miot_set_props(did, [(iid[0], iid[1], value)]))[0]
async def miot_action(self, did, iid, args=[]):
result = await self.miot_request('action', {'did': did, 'siid': iid[0], 'aiid': iid[1], 'in': args})
return result.get('code', -1)
async def device_list(self, name=None, getVirtualModel=False, getHuamiDevices=0):
result = await self.miio_request('/home/device_list', {'getVirtualModel': bool(getVirtualModel), 'getHuamiDevices': int(getHuamiDevices)})
result = result['list']
return result if name == 'full' else [{'name': i['name'], 'model': i['model'], 'did': i['did'], 'token': i['token']} for i in result if not name or name in i['name']]
async def miot_spec(self, type=None, format=None):
if not type or not type.startswith('urn'):
def get_spec(all):
if not type:
return all
ret = {}
for m, t in all.items():
if type == m:
return {m: t}
elif type in m:
ret[m] = t
return ret
import tempfile
path = os.path.join(tempfile.gettempdir(), 'miservice_miot_specs.json')
try:
with open(path) as f:
result = get_spec(json.load(f))
except:
result = None
if not result:
async with self.account.session.get('http://miot-spec.org/miot-spec-v2/instances?status=all') as r:
all = {i['model']: i['type'] for i in (await r.json())['instances']}
with open(path, 'w') as f:
json.dump(all, f)
result = get_spec(all)
if len(result) != 1:
return result
type = list(result.values())[0]
url = 'http://miot-spec.org/miot-spec-v2/instance?type=' + type
async with self.account.session.get(url) as r:
result = await r.json()
def parse_desc(node):
desc = node['description']
# pos = desc.find(' ')
# if pos != -1:
# return (desc[:pos], ' # ' + desc[pos + 2:])
name = ''
for i in range(len(desc)):
d = desc[i]
if d in '-—{「[【(<《':
return (name, ' # ' + desc[i:])
name += '_' if d == ' ' else d
return (name, '')
def make_line(siid, iid, desc, comment, readable=False):
value = f"({siid}, {iid})" if format == 'python' else iid
return f" {'' if readable else '_'}{desc} = {value}{comment}\n"
if format != 'json':
STR_HEAD, STR_SRV, STR_VALUE = ('from enum import Enum\n\n', '\nclass {}(tuple, Enum):\n', '\nclass {}(int, Enum):\n') if format == 'python' else ('', '{} = {}\n', '{}\n')
text = '# Generated by https://github.com/Yonsm/MiService\n# ' + url + '\n\n' + STR_HEAD
svcs = []
vals = []
for s in result['services']:
siid = s['iid']
svc = s['description'].replace(' ', '_')
svcs.append(svc)
text += STR_SRV.format(svc, siid)
for p in s.get('properties', []):
name, comment = parse_desc(p)
access = p['access']
comment += ''.join([' # ' + k for k, v in [(p['format'], 'string'), (''.join([a[0] for a in access]), 'r')] if k and k != v])
text += make_line(siid, p['iid'], name, comment, 'read' in access)
if 'value-range' in p:
valuer = p['value-range']
length = min(3, len(valuer))
values = {['MIN', 'MAX', 'STEP'][i]: valuer[i] for i in range(length) if i != 2 or valuer[i] != 1}
elif 'value-list' in p:
values = {i['description'].replace(' ', '_') if i['description'] else str(i['value']): i['value'] for i in p['value-list']}
else:
continue
vals.append((svc + '_' + name, values))
if 'actions' in s:
text += '\n'
for a in s['actions']:
name, comment = parse_desc(a)
comment += ''.join([f" # {io}={a[io]}" for io in ['in', 'out'] if a[io]])
text += make_line(siid, a['iid'], name, comment)
text += '\n'
for name, values in vals:
text += STR_VALUE.format(name)
for k, v in values.items():
text += f" {'_' + k if k.isdigit() else k} = {v}\n"
text += '\n'
if format == 'python':
text += '\nALL_SVCS = (' + ', '.join(svcs) + ')\n'
result = text
return result
@staticmethod
def miot_decode(ssecurity, nonce, data, gzip=False):
from Crypto.Cipher import ARC4
r = ARC4.new(base64.b64decode(MiIOService.sign_nonce(ssecurity, nonce)))
r.encrypt(bytes(1024))
decrypted = r.encrypt(base64.b64decode(data))
if gzip:
try:
from io import BytesIO
from gzip import GzipFile
compressed = BytesIO()
compressed.write(decrypted)
compressed.seek(0)
decrypted = GzipFile(fileobj=compressed, mode='rb').read()
except:
pass
return json.loads(decrypted.decode())
@staticmethod
def sign_nonce(ssecurity, nonce):
m = hashlib.sha256()
m.update(base64.b64decode(ssecurity))
m.update(base64.b64decode(nonce))
return base64.b64encode(m.digest()).decode()
@staticmethod
def sign_data(uri, data, ssecurity):
if not isinstance(data, str):
data = json.dumps(data)
nonce = base64.b64encode(os.urandom(8) + int(time.time() / 60).to_bytes(4, 'big')).decode()
snonce = MiIOService.sign_nonce(ssecurity, nonce)
msg = '&'.join([uri, snonce, nonce, 'data=' + data])
sign = hmac.new(key=base64.b64decode(snonce), msg=msg.encode(), digestmod=hashlib.sha256).digest()
return {'_nonce': nonce, 'data': data, 'signature': base64.b64encode(sign).decode()}

View File

@@ -0,0 +1,50 @@
import json
from .miaccount import MiAccount, get_random
import logging
_LOGGER = logging.getLogger(__package__)
class MiNAService:
def __init__(self, account: MiAccount):
self.account = account
async def mina_request(self, uri, data=None):
requestId = 'app_ios_' + get_random(30)
if data is not None:
data['requestId'] = requestId
else:
uri += '&requestId=' + requestId
headers = {'User-Agent': 'MiHome/6.0.103 (com.xiaomi.mihome; build:6.0.103.1; iOS 14.4.0) Alamofire/6.0.103 MICO/iOSApp/appStore/6.0.103'}
return await self.account.mi_request('micoapi', 'https://api2.mina.mi.com' + uri, data, headers)
async def device_list(self, master=0):
result = await self.mina_request('/admin/v2/device_list?master=' + str(master))
return result.get('data') if result else None
async def ubus_request(self, deviceId, method, path, message):
message = json.dumps(message)
result = await self.mina_request('/remote/ubus', {'deviceId': deviceId, 'message': message, 'method': method, 'path': path})
return result and result.get('code') == 0
async def text_to_speech(self, deviceId, text):
return await self.ubus_request(deviceId, 'text_to_speech', 'mibrain', {'text': text})
async def player_set_volume(self, deviceId, volume):
return await self.ubus_request(deviceId, 'player_set_volume', 'mediaplayer', {'volume': volume, 'media': 'app_ios'})
async def send_message(self, devices, devno, message, volume=None): # -1/0/1...
result = False
for i in range(0, len(devices)):
if devno == -1 or devno != i + 1 or devices[i]['capabilities'].get('yunduantts'):
_LOGGER.debug("Send to devno=%d index=%d: %s", devno, i, message or volume)
deviceId = devices[i]['deviceID']
result = True if volume is None else await self.player_set_volume(deviceId, volume)
if result and message:
result = await self.text_to_speech(deviceId, message)
if not result:
_LOGGER.error("Send failed: %s", message or volume)
if devno != -1 or not result:
break
return result

View File

@@ -0,0 +1,63 @@
import logging
import time
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from astrbot.core.message.components import BaseMessageComponent
from .client import SimpleMiSpeakerClient
from .mispeaker_event import MiSpeakerPlatformEvent
from astrbot.core import logger
@register_platform_adapter("mispeaker", "小爱音箱")
class MiSpeakerPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
pass
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"mispeaker",
"小爱音箱",
)
async def handle_msg(self, message: AstrBotMessage):
message_event = MiSpeakerPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)
def run(self):
self.client = SimpleMiSpeakerClient(
self.config
)
async def on_event_received(abm: AstrBotMessage):
logger.info(f"on_event_received: {abm}")
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
return self._run()
async def _run(self):
await self.client.initialize()
await self.client.start_pooling()

View File

@@ -0,0 +1,30 @@
import random
import asyncio
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from .client import SimpleMiSpeakerClient
class MiSpeakerPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleMiSpeakerClient
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(message: MessageChain, user_name: str):
pass
async def send(self, message: MessageChain):
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.send(comp.text)
await super().send(message)

View File

@@ -5,7 +5,7 @@ import botpy.types.message
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Reply
from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
@@ -14,33 +14,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
'''QQ 官方 API 仅支持回复一次'''
source = self.message_obj.raw_message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
ref = None
for i in self.send_buffer.chain:
if isinstance(i, Reply):
try:
ref = self.message_obj.raw_message.message_reference
ref = botpy.types.message.Reference(
message_id=ref.message_id,
ignore_get_message_error=False
)
except BaseException as _:
pass
break
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
payload = {
'content': plain_text,
@@ -49,37 +28,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
match type(source):
case botpy.message.GroupMessage:
if ref:
payload['message_reference'] = ref
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if ref:
payload['message_reference'] = ref
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if ref:
payload['message_reference'] = ref
if image_path:
payload['file_image'] = image_path
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
case botpy.message.DirectMessage:
if ref:
payload['message_reference'] = ref
if image_path:
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(self.send_buffer)
self.send_buffer = None
await super().send(message)
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
payload = {
@@ -111,7 +80,4 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
return plain_text, image_base64, image_file_path

View File

@@ -32,10 +32,6 @@ class WebChatMessageEvent(AstrMessageEvent):
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
web_chat_back_queue.put_nowait(None)
await super().send(message)

View File

@@ -2,7 +2,6 @@ import enum
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
class ProviderType(enum.Enum):
@@ -52,7 +51,4 @@ class LLMResponse:
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
'''工具调用名称'''

View File

@@ -108,19 +108,13 @@ class FuncCall:
for f in self.func_list:
if not f.active:
continue
func_declaration = {
"name": f.name,
"description": f.description
}
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
if params.get("properties", {}):
func_declaration["parameters"] = params
tools.append(func_declaration)
tools.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
declarations["function_declarations"] = tools
return declarations

View File

@@ -1,7 +1,6 @@
import traceback
import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .provider import Provider, STTProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
@@ -14,11 +13,8 @@ class ProviderManager():
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
self.personas: List[Personality] = []
self.selected_default_persona = None
@@ -30,7 +26,7 @@ class ProviderManager():
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
begin_dialogs = []
continue
user_turn = True
for dialog in begin_dialogs:
bd_processed.append({
@@ -42,9 +38,9 @@ class ProviderManager():
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
mood_imitation_dialogs = []
continue
user_turn = True
for dialog in mood_imitation_dialogs:
for dialog in begin_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
@@ -63,24 +59,16 @@ class ProviderManager():
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
@@ -90,16 +78,12 @@ class ProviderManager():
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
changed = False
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}")
provider_cfg['id'] = new_id
changed = True
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
self.loaded_ids[provider_cfg['id']] = True
try:
@@ -119,37 +103,25 @@ class ProviderManager():
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
if changed:
try:
config.save_config()
except Exception as e:
logger.warning(f"保存配置文件失败:{e}")
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
async def initialize(self):
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
@@ -166,18 +138,6 @@ class ProviderManager():
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
@@ -207,18 +167,11 @@ class ProviderManager():
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst:
if self.provider_stt_settings.get("enable"):
if not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
def get_insts(self):
return self.provider_insts

View File

@@ -24,32 +24,9 @@ class ProviderMeta():
id: str
model: str
type: str
class AbstractProvider(abc.ABC):
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
self.provider_config = provider_config
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class Provider(AbstractProvider):
class Provider(abc.ABC):
def __init__(
self,
provider_config: dict,
@@ -58,11 +35,14 @@ class Provider(AbstractProvider):
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
super().__init__(provider_config)
self.model_name = ""
'''当前使用的模型名称'''
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
@@ -74,10 +54,18 @@ class Provider(AbstractProvider):
if persistant_history:
# 读取历史记录
try:
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
@abc.abstractmethod
def get_current_key(self) -> str:
@@ -145,11 +133,17 @@ class Provider(AbstractProvider):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider(AbstractProvider):
class STTProvider():
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@@ -157,15 +151,19 @@ class STTProvider(AbstractProvider):
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_audio(self, text: str) -> str:
'''获取文本的音频,返回音频文件路径'''
raise NotImplementedError()
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获取当前使用的模型'''
return self.provider_config.get("model", "")
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)

View File

@@ -260,11 +260,10 @@ class ProviderGoogleGenAI(Provider):
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
def get_current_key(self) -> str:

View File

@@ -118,11 +118,10 @@ class LLMTunerModelLoader(Provider):
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id):
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
async def get_current_key(self):

View File

@@ -1,6 +1,6 @@
import traceback
import base64
import json
import re
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
@@ -72,21 +72,31 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def pop_record(self, session_id: str):
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出最早的一个对话
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) < 2:
return
try:
self.session_memory[session_id].pop(0)
self.session_memory[session_id].pop(0)
except IndexError:
pass
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
@@ -94,28 +104,13 @@ class ProviderOpenAIOfficial(Provider):
if tool_list:
payloads['tools'] = tool_list
completion = None
try:
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
except BaseException as e:
# 处理不支持 Function Calling 的模型
if 'does not support Function Calling' in str(e) \
or 'does not support tools' in str(e) \
or 'Function call is not supported' in str(e): # siliconcloud
del payloads['tools']
logger.debug(f"模型 {self.model_name} 不支持 tools已自动移除")
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
else:
raise e
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion}")
logger.debug(f"completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
@@ -124,8 +119,7 @@ class ProviderOpenAIOfficial(Provider):
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
return LLMResponse("assistant", completion_text, raw_completion=completion)
return LLMResponse("assistant", completion_text)
elif choice.message.tool_calls:
# tools call (function calling)
args_ls = []
@@ -136,9 +130,8 @@ class ProviderOpenAIOfficial(Provider):
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls, raw_completion=completion)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
else:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception("Internal Error")
async def text_chat(
@@ -171,29 +164,25 @@ class ProviderOpenAIOfficial(Provider):
try:
llm_response = await self._query(payloads, func_tool)
if kwargs.get("persist", True):
await self.save_history(contexts, new_record, session_id, llm_response)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 10
while retry_cnt > 0:
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
logger.warning(f"请求失败:{e}上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
await self.pop_record(session_id)
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
if kwargs.get("persist", True):
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
raise e
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
@@ -212,11 +201,10 @@ class ProviderOpenAIOfficial(Provider):
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
def get_current_key(self) -> str:

View File

@@ -1,40 +0,0 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
class ProviderOpenAITTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("openai-tts-voice", "alloy")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def get_audio(self, text: str) -> str:
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format='wav',
input=text
) as response:
with open(path, 'wb') as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path

View File

@@ -1,7 +1,3 @@
'''
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
'''
from typing import Union
import os
import json

View File

@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata, star_map
from .star import star_registry, StarMetadata
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
@@ -54,126 +54,19 @@ class Context:
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
for star in star_registry:
if star.name == star_name:
return star
def get_all_stars(self) -> List[StarMetadata]:
'''获取当前载入的所有插件 Metadata 的列表'''
return star_registry
def get_llm_tool_manager(self) -> FuncCall:
'''获取 LLM Tool Manager其用于管理注册的所有的 Function-calling tools'''
'''
获取 LLM Tool Manager
'''
return self.provider_manager.llm_tools
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
Returns:
如果没找到,会返回 False
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def deactivate_llm_tool(self, name: str) -> bool:
'''停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''获取 AstrBot 的配置。'''
return self._config
def get_db(self) -> BaseDatabase:
'''获取 AstrBot 数据库。'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
'''
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
'''
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用function-calling / tools-use添加工具。
@@ -201,7 +94,41 @@ class Context:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
Returns:
如果没找到,会返回 False
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def deactivate_llm_tool(self, name: str) -> bool:
'''停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
@@ -235,6 +162,77 @@ class Context:
))
star_handlers_registry.append(md)
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider(Chat_Completion 类型)。
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider(Chat_Completion 类型)。
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
def register_task(self, task: Awaitable, desc: str):
'''
注册一个异步任务。

View File

@@ -7,7 +7,6 @@ from .star_handler import (
register_regex,
register_permission_type,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
register_on_decorating_result,
register_after_message_sent
@@ -22,7 +21,6 @@ __all__ = [
'register_regex',
'register_permission_type',
'register_on_llm_request',
'register_on_llm_response',
'register_llm_tool',
'register_on_decorating_result',
'register_after_message_sent'

View File

@@ -139,8 +139,6 @@ def register_on_llm_request():
Examples:
```py
from astrbot.api.provider import ProviderRequest
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
@@ -154,27 +152,6 @@ def register_on_llm_request():
return decorator
def register_on_llm_response():
'''当有 LLM 请求后的事件
Examples:
```py
from astrbot.api.provider import LLMResponse
@on_llm_response()
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
...
```
请务必接收两个参数event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
return awaitable
return decorator
def register_llm_tool(name: str = None):
'''为函数调用function-calling / tools-use添加工具。

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
@@ -12,7 +11,7 @@ star_map: Dict[str, StarMetadata] = {}
@dataclass
class StarMetadata:
'''
插件的元数据。
Star 的元数据。
'''
name: str
author: str # 插件作者
@@ -21,24 +20,21 @@ class StarMetadata:
repo: str = None # 插件仓库地址
star_cls_type: type = None
'''插件的类对象的类型'''
'''Star 的类对象的类型'''
module_path: str = None
'''插件的模块路径'''
'''Star 的模块路径'''
star_cls: object = None
'''插件的类对象'''
'''Star 的类对象'''
module: ModuleType = None
'''插件的模块对象'''
'''Star 的模块对象'''
root_dir_name: str = None
'''插件的目录名'''
'''Star 的根目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留插件'''
'''是否是 AstrBot 的保留 Star'''
activated: bool = True
'''是否被激活'''
config: AstrBotConfig = None
'''插件配置'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"

View File

@@ -47,7 +47,6 @@ class EventType(enum.Enum):
'''
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后
OnDecoratingResultEvent = enum.auto() # 发送消息前
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后

View File

@@ -2,14 +2,12 @@ import inspect
import functools
import os
import sys
import json
import traceback
import yaml
import logging
from types import ModuleType
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import DEFAULT_VALUE_MAP
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
@@ -28,20 +26,13 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self
self.context._star_manager = self # 就这样吧,不想改了
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
'''存储插件的路径。即 data/plugins'''
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
'''存储插件配置的路径。data/config'''
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
'''保留插件的路径。在 packages 目录下'''
self.conf_schema_fname = "_conf_schema.json"
'''插件配置 Schema 文件名'''
def _get_classes(self, arg: ModuleType):
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers:
@@ -137,7 +128,7 @@ class PluginManager:
return metadata
async def reload(self):
'''扫描并加载所有的插件'''
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
@@ -159,13 +150,13 @@ class PluginManager:
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 导入插件模块,并尝试实例化插件
# 导入 Star 模块,并尝试实例化 Star
for plugin_module in plugin_modules:
try:
module_str = plugin_module['module']
# module_path = plugin_module['module_path']
root_dir_name = plugin_module['pname'] # 插件的目录名
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
root_dir_name = plugin_module['pname']
reserved = plugin_module.get('reserved', False)
logger.info(f"正在载入插件 {root_dir_name} ...")
@@ -182,33 +173,11 @@ class PluginManager:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
if os.path.exists(plugin_schema_path):
# 加载插件配置
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
schema=json.loads(f.read())
)
if path in star_map:
# 通过装饰器的方式注册插件
metadata = star_map[path]
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(context=self.context)
else:
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -230,20 +199,16 @@ class PluginManager:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
classes = self._get_classes(module)
try:
obj = getattr(module, classes[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
if plugin_config:
try:
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
else:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -256,7 +221,7 @@ class PluginManager:
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize() 方法
# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
@@ -327,14 +292,13 @@ class PluginManager:
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
@@ -359,9 +323,8 @@ class PluginManager:
plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
desti_dir = os.path.join(self.plugin_store_path, dir_name)
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
self.updator.unzip_file(zip_file_path, desti_dir)
# remove the zip
@@ -369,4 +332,6 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
await self.reload()
self._check_plugin_dept_update()

View File

@@ -6,8 +6,6 @@ import time
import aiohttp
import base64
import zipfile
import uuid
from typing import Union
from PIL import Image
@@ -43,21 +41,21 @@ def port_checker(port: int, host: str = "localhost"):
return False
def save_temp_img(img: Union[Image.Image, str]) -> str:
def save_temp_img(img: Image) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过 12 小时的
# 获得文件创建时间,清除超过1小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600*12:
if time.time() - ctime > 3600:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
timestamp = int(time.time())
p = f"data/temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
@@ -109,7 +107,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
'''
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, timeout=1800) as resp:
async with session.get(url, timeout=120) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get('content-length', 0))

View File

@@ -20,23 +20,18 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
return output_path
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
'''返回 duration'''
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
import pysilk
with wave.open(wav_path, 'rb') as wav:
wav_data = wav.readframes(wav.getnframes())
wav_data = BytesIO(wav_data)
output_io = BytesIO()
pysilk.encode(wav_data, output_io, 24000, 24000)
pysilk.encode(wav_data, output_io, 24000)
output_io.seek(0)
# 在首字节添加 \x02,去除结尾的\xff\xff
# 在首字节添加 \x02
silk_data = output_io.read()
silk_data_with_prefix = b'\x02' + silk_data[:-2]
silk_data_with_prefix = b'\x02' + silk_data
# return BytesIO(silk_data_with_prefix)
with open(output_path, "wb") as f:
f.write(silk_data_with_prefix)
return 0
return BytesIO(silk_data_with_prefix)

View File

@@ -39,6 +39,7 @@ class RepoZipUpdator():
else:
ret = self.github_api_release_parser(result)
except BaseException:
logger.error("解析版本信息失败")
raise Exception("解析版本信息失败")
return ret

View File

@@ -1,12 +1,14 @@
import os
import json
import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -17,9 +19,9 @@ def try_cast(value: str, type_: str):
elif type_ == "float" and isinstance(value, int):
return float(value)
def validate_config(data, schema: dict, is_core: bool):
def validate_config(data, config: AstrBotConfig):
errors = []
def validate(data, metadata=schema, path=""):
def validate(data, metadata=CONFIG_METADATA_2, path=""):
for key, meta in metadata.items():
if key not in data:
continue
@@ -54,33 +56,35 @@ def validate_config(data, schema: dict, is_core: bool):
elif meta["type"] == "object" and not isinstance(value, dict):
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
if is_core:
for key, group in schema.items():
group_meta = group.get("metadata")
if not group_meta:
continue
logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
validate(data)
return errors
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
'''验证并保存配置'''
errors = None
try:
if is_core:
errors = validate_config(post_config, CONFIG_METADATA_2, is_core)
else:
errors = validate_config(post_config, config.schema, is_core)
except BaseException as e:
logger.warning(f"验证配置时出现异常: {e}")
errors = validate_config(post_config, config)
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
def save_extension_config(post_config: dict):
if 'namespace' not in post_config:
raise ValueError("Missing key: namespace")
if 'config' not in post_config:
raise ValueError("Missing key: config")
namespace = post_config['namespace']
config: list = post_config['config'][0]['body']
for item in config:
key = item['path']
value = item['value']
typ = item['val_type']
if typ == 'int':
if not value.isdigit():
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
value = int(value)
update_config(namespace, key, value)
class ConfigRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
@@ -88,17 +92,17 @@ class ConfigRoute(Route):
self.routes = {
'/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs),
'/config/plugin/update': ('POST', self.post_extension_configs),
}
self.register_routes()
async def get_configs(self):
# plugin_name 为空时返回 AstrBot 配置
# 否则返回指定 plugin_name 的插件配置
plugin_name = request.args.get("plugin_name", None)
if not plugin_name:
# namespace 为空时返回 AstrBot 配置
# 否则返回指定 namespace 的插件配置
namespace = "" if "namespace" not in request.args else request.args["namespace"]
if not namespace:
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
return Response().ok(await self._get_extension_config(namespace)).__dict__
async def post_astrbot_configs(self):
post_configs = await request.json
@@ -106,15 +110,14 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
logger.error(e)
traceback.print_exc()
return Response().error(str(e)).__dict__
async def post_plugin_configs(self):
async def post_extension_configs(self):
post_configs = await request.json
plugin_name = request.args.get("plugin_name", "unknown")
try:
await self._save_plugin_configs(post_configs, plugin_name)
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
await self._save_extension_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
return Response().error(str(e)).__dict__
@@ -138,48 +141,28 @@ class ConfigRoute(Route):
"config": config
}
async def _get_plugin_config(self, plugin_name: str):
ret = {
"metadata": None,
"config": None
}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
if not plugin_md.config:
break
ret['config'] = plugin_md.config # 这是自定义的 Dict 类AstrBotConfig
ret['metadata'] = {
plugin_name: {
"description": f"{plugin_name} 配置",
"type": "object",
"items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema
}
}
break
return ret
async def _get_extension_config(self, namespace: str):
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
return [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
async def _save_astrbot_configs(self, post_configs: dict):
try:
save_config(post_configs, self.config, is_core=True)
save_astrbot_config(post_configs, self.config)
self.core_lifecycle.restart()
except Exception as e:
raise e
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
md = None
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
md = plugin_md
if not md:
raise ValueError(f"插件 {plugin_name} 不存在")
if not md.config:
raise ValueError(f"插件 {plugin_name} 没有注册配置")
async def _save_extension_configs(self, post_configs: dict):
try:
save_config(post_configs, md.config)
save_extension_config(post_configs)
self.core_lifecycle.restart()
except Exception as e:
raise e

View File

@@ -67,9 +67,9 @@ class PluginRoute(Route):
file = await request.files
file = file['file']
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}"
file_path = f"data/temp/{uuid.uuid4()}.zip"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:

View File

@@ -6,10 +6,6 @@ class StaticFileRoute(Route):
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat']
for i in index_:
self.app.add_url_rule(i, view_func=self.index)
@self.app.errorhandler(404)
async def page_not_found(e):
return "404 Not found。如果你初次使用打开面板发现 404请参考文档: https://astrbot.app/deploy/dashboard-404.html"
async def index(self):
return await self.app.send_static_file('index.html')

View File

@@ -1,5 +1,4 @@
import traceback
import aiohttp
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -44,7 +43,7 @@ class UpdateRoute(Route):
}
).__dict__
except Exception as e:
logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)")
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
async def update_project(self):
@@ -64,13 +63,6 @@ class UpdateRoute(Route):
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
# pip 更新依赖
logger.info("更新依赖中...")
try:
pip_installer.install(requirements_path="requirements.txt")
except Exception as e:
logger.error(f"更新依赖失败: {e}")
if reboot:
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()

View File

@@ -1,6 +0,0 @@
# What's Changed
- Gewechat 微信支持图片、语音的收和发
- 支持 OpenAI TTS文字转语音
- 支持路径映射,解决 docker 部署时两端文件系统不一致导致的富媒体文件路径不存在问题
- Napcat 下语音消息可能接收异常

View File

@@ -1,4 +0,0 @@
# What's Changed
- 修复 astrbot_updator 属性缺失与stt_enabled 未初始化 #252
- 支持消息分段回复

View File

@@ -1,8 +0,0 @@
# What's Changed
- 修复: TTS 问题
- 新增: **支持记录非唤醒状态下群聊历史记录(beta)**
- 优化: 自动删除 deepseek-r1 模型自带的 think 标签
- 优化: 自动移除 ollama 不支持 tool 的模型的 tool 请求
- 优化: /t2i 即时生效
- 优化: gewechat 消息下发异常处理

View File

@@ -1,9 +0,0 @@
# What's Changed
- 修复: 配置 Validator 不起效的问题
- 修复: DeepSeek-R1 思考标签问题
- 修复: 分段回复间隔时间不生效
- 修复: 修复白名单为空时依然终止事件 #259
- 修复: 群聊增强某些参数的类型转换问题
- 新增: 插件支持注册配置,详见 [注册插件配置](https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta)
- 优化: 插件的禁用/启用逻辑以及函数工具的禁用/启用逻辑

View File

@@ -1,6 +0,0 @@
# What's Changed
- [gewechat] [修复每次启动astrbot都需要扫码的问题](https://github.com/Soulter/AstrBot/commit/fd5d7dd37a6d74f81a148bbebef8516aa0cb5540)
- [core] [Provider 重复时不直接报错闪退](https://github.com/Soulter/AstrBot/commit/b61f9be18db9a6b8b3c5b6b36553f66dd2b79375) https://github.com/Soulter/AstrBot/issues/265
- [core] [弱化更新报错](https://github.com/Soulter/AstrBot/commit/0ba0150fd8ff2062dbe83889163888ba3e33bd49) https://github.com/Soulter/AstrBot/issues/267
- 修复 webui 无法从本地上传插件的问题

View File

@@ -1,11 +0,0 @@
# What's Changed
- [beta] 支持群聊内基于概率的主动回复
- openai tts 更换模型 #300
- 增加模型响应后的插件钩子
- 修复 相同type的provider共享了记忆
- 优化 人格情景在发现格式不对时仍然加载而不是跳过 #282
- 修复 Gemini函数调用时parameters为空对象导致的错误 by @Camreishi
- 修复 弹出记录报错的问题 #272
- 优化 移除默认人格
- 优化 未启用模型提供商时的异常处理

View File

@@ -4,8 +4,6 @@ services:
astrbot:
image: soulter/astrbot:latest
container_name: astrbot
ports:
- "6180-6200:6180-6200"
- "11451:11451"
network_mode: "host"
volumes:
- ./data:/AstrBot/data
- ./data:/AstrBot/data

View File

@@ -0,0 +1,41 @@
<script setup>
import UiParentCard from '@/components/shared/UiParentCard.vue';
const props = defineProps({
config: Array
});
</script>
<template>
<a v-show="config.length === 0">该插件没有配置</a>
<UiParentCard v-for="group in config" :key="group.name" :title="group.name" style="margin-bottom: 16px;">
<template v-for="item in group.body">
<template v-if="item.config_type === 'item'">
<template v-if="item.val_type === 'bool'">
<v-switch v-model="item.value" :label="item.name" :hint="item.description" color="primary" inset></v-switch>
</template>
<template v-else-if="item.val_type === 'str'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'int'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'list'">
<span>{{ item.name }}</span>
<v-combobox v-model="item.value" chips clearable label="请添加" multiple prepend-icon="mdi-tag-multiple-outline">
<template v-slot:selection="{ attrs, item, select, selected }">
<v-chip v-bind="attrs" :model-value="selected" closable @click="select" @click:close="remove(item)">
<strong>{{ item }}</strong>
</v-chip>
</template>
</v-combobox>
</template>
</template>
<template v-else-if="item.config_type === 'divider'">
<v-divider style="margin-top: 8px; margin-bottom: 8px;"></v-divider>
</template>
</template>
</UiParentCard>
</template>

View File

@@ -1,7 +1,7 @@
<script setup>
import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import ConfigDetailCard from '@/components/shared/ConfigDetailCard.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
@@ -52,17 +52,11 @@ import axios from 'axios';
<v-btn v-else variant="plain" disabled>已安装</v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col style="margin-bottom: 16px;" cols="12" md="12">
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
</v-col>
</v-row>
<v-dialog v-model="configDialog" width="1000">
<v-dialog v-model="configDialog" width="750">
<template v-slot:activator="{ props }">
</template>
<v-card>
@@ -71,8 +65,7 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-container>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<p v-else>这个插件没有配置</p>
<ConfigDetailCard :config="extension_config"></ConfigDetailCard>
</v-container>
</v-card-text>
<v-card-actions>
@@ -179,9 +172,9 @@ export default {
name: 'ExtensionPage',
components: {
ExtensionCard,
ConfigDetailCard,
WaitingForRestart,
ConsoleDisplayer,
AstrBotConfig
ConsoleDisplayer
},
data() {
return {
@@ -196,10 +189,7 @@ export default {
snack_success: "success",
loading_: false,
configDialog: false,
extension_config: {
"metadata": {},
"config": {}
},
extension_config: {},
upload_file: null,
pluginMarketData: {},
loadingDialog: {
@@ -374,7 +364,7 @@ export default {
openExtensionConfig(extension_name) {
this.curr_namespace = extension_name;
this.configDialog = true;
axios.get('/api/config/get?plugin_name=' + extension_name).then((res) => {
axios.get('/api/config/get?namespace=' + extension_name).then((res) => {
this.extension_config = res.data.data;
console.log(this.extension_config);
}).catch((err) => {
@@ -382,7 +372,10 @@ export default {
});
},
updateConfig() {
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
axios.post('/api/config/plugin/update', {
"config": this.extension_config,
"namespace": this.curr_namespace
}).then((res) => {
if (res.data.status === "ok") {
this.toast(res.data.message, "success");
this.$refs.wfr.check();

View File

@@ -1,127 +0,0 @@
import datetime
import uuid
import random
import astrbot.api.star as star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.platform import MessageType
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Plain, Image
from astrbot import logger
from collections import defaultdict
'''
聊天记忆增强
'''
class LongTermMemory:
def __init__(self, config: dict, context: star.Context):
self.config = config
self.context = context
self.session_chats = defaultdict(list)
"""记录群成员的群聊记录"""
try:
self.max_cnt = int(self.config["group_message_max_cnt"])
except BaseException as e:
logger.error(e)
self.max_cnt = 300
self.image_caption = self.config["image_caption"]
self.image_caption_prompt = self.config["image_caption_prompt"]
self.active_reply = self.config["active_reply"]
self.ar_method = self.active_reply["method"]
self.ar_possibility = self.active_reply["possibility_reply"]
self.ar_prompt = self.active_reply.get("prompt", "")
self.put_history_to_prompt = self.config["put_history_to_prompt"]
async def remove_session(self, event: AstrMessageEvent) -> int:
cnt = 0
if event.unified_msg_origin in self.session_chats:
cnt = len(self.session_chats[event.unified_msg_origin])
del self.session_chats[event.unified_msg_origin]
return cnt
async def get_image_caption(self, image_url: str) -> str:
provider = self.context.get_using_provider()
response = await provider.text_chat(
prompt=self.image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
return response.completion_text
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
if not self.active_reply:
return False
if event.get_message_type() != MessageType.GROUP_MESSAGE:
return False
if event.is_at_or_wake_command:
# if the message is a command, let it pass
return False
match self.ar_method:
case "possibility_reply":
trig = random.random() < self.ar_possibility
return trig
return False
async def handle_message(self, event: AstrMessageEvent):
'''仅支持群聊'''
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
for comp in event.get_messages():
if isinstance(comp, Plain):
final_message += f" {comp.text}"
elif isinstance(comp, Image):
# image_urls.append(comp.url if comp.url else comp.file)
if self.image_caption:
try:
caption = await self.get_image_caption(
comp.url if comp.url else comp.file
)
final_message += f" [Image: {caption}]"
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
else:
final_message += " [Image]"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
'''当触发 LLM 请求前,调用此方法修改 req'''
if event.unified_msg_origin not in self.session_chats:
return
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
if self.put_history_to_prompt:
prompt = req.prompt
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
req.contexts = [] # 清空上下文当使用了群聊增强所有聊天记录都在一个prompt中。
else:
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
req.system_prompt += chats_str
if self.image_caption:
req.system_prompt += (
"The images sent by the members are displayed in text form above."
)
async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats:
return
if event.get_result() and event.get_result().is_llm_result():
final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)

View File

@@ -5,17 +5,13 @@ import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import sp
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
from astrbot.api.platform import MessageType
from astrbot.api.provider import Personality, ProviderRequest
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
from collections import defaultdict
from .long_term_memory import LongTermMemory
from astrbot.core import logger
from typing import Union
@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0")
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
@@ -24,12 +20,7 @@ class Main(star.Star):
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.ltm = None
if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']:
try:
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
except BaseException as e:
logger.error(f"聊天增强 err: {e}")
self.kdb_enabled = False
async def _query_astrbot_notice(self):
try:
@@ -203,10 +194,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
async def provider(self, event: AstrMessageEvent, idx: int = None):
'''查看或者切换 LLM Provider'''
if not self.context.get_using_provider():
event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if idx is None:
ret = "## 当前载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
@@ -231,27 +218,11 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
await self.context.get_using_provider().forget(message.session_id)
ret = "清除会话 LLM 聊天历史成功"
if self.ltm:
cnt = await self.ltm.remove_session(event=message)
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
message.set_result(MessageEventResult().message(ret))
message.set_result(MessageEventResult().message("重置成功"))
@filter.command("model")
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if idx_or_name is None:
models = []
try:
@@ -292,12 +263,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
size_per_page = 3
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
@@ -317,10 +282,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int=None):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if index is None:
keys_data = self.context.get_using_provider().get_keys()
@@ -349,12 +310,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
l = message.message_str.split(" ")
curr_persona_name = ""
@@ -368,7 +323,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
- 人格情景列表: `/persona list`
- 人格情景详细信息: `/persona view 人格名`
- 取消人格: `/persona unset`
当前人格情景: {curr_persona_name}
@@ -394,9 +348,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
else:
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
elif l[1] == "unset":
self.context.get_using_provider().curr_personality = None
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
ps = "".join(l[1:]).strip()
if persona := next(builtins.filter(
@@ -404,9 +355,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
self.context.provider_manager.personas
), None):
self.context.get_using_provider().curr_personality = persona
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
else:
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
@@ -415,6 +366,31 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await download_dashboard()
yield event.plain_result("管理面板更新完成。")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
# if provider.curr_personality['prompt']:
# req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
@@ -452,111 +428,32 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await platform.logout()
yield event.plain_result("已登出 gewechat")
return
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''群聊记忆增强'''
if self.ltm:
need_active = await self.ltm.need_active_reply(event)
group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable']
if group_icl_enable:
'''记录对话'''
try:
await self.ltm.handle_message(event)
except BaseException as e:
logger.error(e)
if need_active:
'''主动回复'''
provider = self.context.get_using_provider()
if not provider:
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
return
try:
session_provider_context = provider.session_memory.get(event.session_id)
prompt = self.ltm.ar_prompt
if not prompt:
prompt = event.message_str
yield event.request_llm(
prompt=prompt,
func_tool_manager=self.context.get_llm_tool_manager(),
session_id=event.session_id,
contexts=session_provider_context if session_provider_context else []
)
except BaseException as e:
logger.error(f"主动回复失败: {e}")
@filter.command_group("kdb")
def kdb(self):
pass
@kdb.command("on")
async def on_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = True
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if not curr_kdb_name:
yield event.plain_result("未载入任何知识库")
else:
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
@kdb.command("off")
async def off_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = False
yield event.plain_result("知识库已关闭")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
if self.ltm:
try:
await self.ltm.on_req_llm(event, req)
except BaseException as e:
logger.error(f"ltm: {e}")
@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
'''在 LLM 请求后记录对话'''
if self.ltm:
try:
await self.ltm.after_req_llm(event)
except BaseException as e:
logger.error(f"ltm: {e}")
# @filter.command_group("kdb")
# def kdb(self):
# pass
# @kdb.command("on")
# async def on_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = True
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if not curr_kdb_name:
# yield event.plain_result("未载入任何知识库")
# else:
# yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
# @kdb.command("off")
# async def off_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = False
# yield event.plain_result("知识库已关闭")
# @filter.on_llm_request()
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if self.kdb_enabled and curr_kdb_name:
# mgr = self.context.knowledge_db_manager
# results = await mgr.retrive_records(curr_kdb_name, req.prompt)
# if results:
# req.system_prompt += "\nHere are documents that related to user's query: \n"
# for result in results:
# req.system_prompt += f"- {result}\n"
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if self.kdb_enabled and curr_kdb_name:
mgr = self.context.knowledge_db_manager
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
if results:
req.system_prompt += "\nHere are documents that related to user's query: \n"
for result in results:
req.system_prompt += f"- {result}\n"

View File

@@ -8,8 +8,7 @@ from typing import List
class Google(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.proxy = os.environ.get("https_proxy")
print(f"Google Search using proxy: {self.proxy}")
self.proxy = os.environ.get("HTTPS_PROXY")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = []

View File

@@ -22,8 +22,6 @@ class Main(star.Star):
self.sogo_search = Sogo()
self.google = Google()
self.websearch_link = self.context.get_config()['provider_settings'].get('web_search_link', False)
async def initialize(self):
websearch = self.context.get_config()['provider_settings']['web_search']
if websearch:
@@ -111,17 +109,8 @@ class Main(star.Star):
except BaseException:
site_result = ""
site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result
header = f"{idx}. {i.title} "
if self.websearch_link and i.url:
header += i.url
ret += f"{header}\n{i.snippet}\n{site_result}\n\n"
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
if self.websearch_link:
ret += "针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
return ret

View File

@@ -1,4 +1,5 @@
pydantic~=2.10.3
vchat
aiohttp
openai
qq-botpy

33
stt.py Normal file
View File

@@ -0,0 +1,33 @@
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

View File

@@ -5,7 +5,6 @@ import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import CONFIG_METADATA_2
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType