Compare commits
1 Commits
v3.4.17
...
feat-platf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c18971f00 |
49
README.md
49
README.md
@@ -1,12 +1,14 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||
<p align="center">
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/de10f24d-cd64-433a-90b8-16c0a60de24a" width=500>
|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<h1>AstrBot</h1>
|
||||
|
||||
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
@@ -14,8 +16,9 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||
</a>
|
||||
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
@@ -35,7 +38,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
> [!TIP]
|
||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM,因此无法在聊天页使用大模型。
|
||||
|
||||
## ✨ 使用方式
|
||||
|
||||
@@ -64,31 +67,19 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
|
||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| QQ 官方API | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信(企业微信) | 🚧 | 计划内 | - |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| 飞书 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
|
||||
# 🦌 接下来的路线图
|
||||
|
||||
> [!TIP]
|
||||
> 欢迎在 Issue 提出更多建议 <3
|
||||
|
||||
- [ ] 完善并保证目前所有平台适配器的功能一致性
|
||||
- [ ] 优化插件接口
|
||||
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
|
||||
- [ ] 完善“聊天增强”部分,支持持久化记忆
|
||||
- [ ] 规划 i18n
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
@@ -140,21 +131,8 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
## Sponsors
|
||||
|
||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||
3. Please ensure compliance with local laws and regulations when using this project.
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
@@ -166,6 +144,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
_アトリは、高性能ですから!_
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
register_permission_type as permission_type,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_on_llm_response as on_llm_response,
|
||||
register_llm_tool as llm_tool,
|
||||
register_on_decorating_result as on_decorating_result,
|
||||
register_after_message_sent as after_message_sent
|
||||
@@ -32,6 +31,5 @@ __all__ = [
|
||||
'on_llm_request',
|
||||
'llm_tool',
|
||||
'on_decorating_result',
|
||||
'after_message_sent',
|
||||
'on_llm_response'
|
||||
'after_message_sent'
|
||||
]
|
||||
@@ -2,5 +2,4 @@ from astrbot.core.platform import (
|
||||
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
)
|
||||
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
from astrbot.core.message.components import *
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
@@ -1,2 +1,2 @@
|
||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
|
||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import json
|
||||
import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from .default import DEFAULT_CONFIG
|
||||
from typing import Dict
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
@@ -13,72 +13,29 @@ class RateLimitStrategy(enum.Enum):
|
||||
DISCARD = "discard"
|
||||
|
||||
class AstrBotConfig(dict):
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
|
||||
|
||||
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
|
||||
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
default_config: dict = DEFAULT_CONFIG,
|
||||
schema: dict = None
|
||||
):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
|
||||
object.__setattr__(self, 'config_path', config_path)
|
||||
object.__setattr__(self, 'default_config', default_config)
|
||||
object.__setattr__(self, 'schema', schema)
|
||||
|
||||
if schema:
|
||||
default_config = self._config_schema_to_default_config(schema)
|
||||
|
||||
if not self.check_exist():
|
||||
'''不存在时载入默认配置'''
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
|
||||
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
has_new = self.check_config_integrity(default_config, conf)
|
||||
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
|
||||
self.update(conf)
|
||||
if has_new:
|
||||
self.save_config()
|
||||
|
||||
self.update(conf)
|
||||
|
||||
def _config_schema_to_default_config(self, schema: dict) -> dict:
|
||||
'''将 Schema 转换成 Config'''
|
||||
conf = {}
|
||||
|
||||
def _parse_schema(schema: dict, conf: dict):
|
||||
for k, v in schema.items():
|
||||
if v['type'] not in DEFAULT_VALUE_MAP:
|
||||
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
|
||||
if 'default' in v:
|
||||
default = v['default']
|
||||
else:
|
||||
default = DEFAULT_VALUE_MAP[v['type']]
|
||||
|
||||
if v['type'] == 'object':
|
||||
conf[k] = {}
|
||||
_parse_schema(v['items'], conf[k])
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
_parse_schema(schema, conf)
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
'''检查配置完整性,如果有新的配置项则返回 True'''
|
||||
has_new = False
|
||||
@@ -104,7 +61,7 @@ class AstrBotConfig(dict):
|
||||
'''
|
||||
if replace_config:
|
||||
self.update(replace_config)
|
||||
with open(self.config_path, "w", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(self, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def __getattr__(self, item):
|
||||
@@ -124,4 +81,4 @@ class AstrBotConfig(dict):
|
||||
self[key] = value
|
||||
|
||||
def check_exist(self) -> bool:
|
||||
return os.path.exists(self.config_path)
|
||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.17"
|
||||
VERSION = "3.4.11"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -24,20 +24,13 @@ DEFAULT_CONFIG = {
|
||||
"wl_ignore_admin_on_friend": True,
|
||||
"reply_with_mention": False,
|
||||
"reply_with_quote": False,
|
||||
"path_mapping": [],
|
||||
"segmented_reply": {
|
||||
"enable": False,
|
||||
"only_llm_result": True,
|
||||
"interval": "1.5,3.5",
|
||||
"regex": ".*?[。?!~…]+|.+$"
|
||||
}
|
||||
"path_mapping": []
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
"identifier": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
@@ -47,23 +40,6 @@ DEFAULT_CONFIG = {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
"possibility_reply": 0.1,
|
||||
"prompt": "",
|
||||
},
|
||||
"put_history_to_prompt": True,
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||
@@ -86,7 +62,7 @@ DEFAULT_CONFIG = {
|
||||
"persona": [
|
||||
{
|
||||
"name": "default",
|
||||
"prompt": "",
|
||||
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
@@ -119,15 +95,27 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 6199,
|
||||
},
|
||||
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"host": "localhost",
|
||||
"port": 11451,
|
||||
},
|
||||
"mispeaker(小爱音箱)": {
|
||||
"id": "mispeaker",
|
||||
"type": "mispeaker",
|
||||
"enable": False,
|
||||
"username": "",
|
||||
"password": "",
|
||||
"did": "",
|
||||
"activate_word": "测试",
|
||||
"deactivate_word": "停止",
|
||||
"interval": 1,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"id": {
|
||||
@@ -201,31 +189,6 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用分段回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"only_llm_result": {
|
||||
"description": "仅对 LLM 结果分段",
|
||||
"type": "bool",
|
||||
},
|
||||
"interval": {
|
||||
"description": "随机间隔时间(秒)",
|
||||
"type": "string",
|
||||
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
|
||||
},
|
||||
"regex": {
|
||||
"description": "正则表达式",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"reply_prefix": {
|
||||
"description": "回复前缀",
|
||||
"type": "string",
|
||||
@@ -243,9 +206,8 @@ CONFIG_METADATA_2 = {
|
||||
"id_whitelist": {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
"items": {"type": "int"},
|
||||
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
"description": "打印白名单日志",
|
||||
@@ -273,7 +235,6 @@ CONFIG_METADATA_2 = {
|
||||
"path_mapping": {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
}
|
||||
@@ -332,7 +293,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"api_base": "",
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
@@ -421,24 +382,8 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"openai_tts(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "tts-1",
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"openai-tts-voice": {
|
||||
"description": "voice",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||
},
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
@@ -469,7 +414,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
@@ -555,25 +500,16 @@ CONFIG_METADATA_2 = {
|
||||
"web_search": {
|
||||
"description": "启用网页搜索",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
},
|
||||
"web_search_link": {
|
||||
"description": "网页搜索引用链接",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
},
|
||||
"identifier": {
|
||||
"description": "启动识别群员",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"description": "启用日期时间系统提示",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||
},
|
||||
"default_personality": {
|
||||
@@ -615,15 +551,15 @@ CONFIG_METADATA_2 = {
|
||||
"begin_dialogs": {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色之后按回车",
|
||||
"items": {},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。对话需要成对(用户和助手),输入完一个角色之后按回车",
|
||||
"items": {},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
@@ -645,87 +581,6 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"description": "文本转语音(TTS)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用文本转语音(TTS)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"description": "聊天记忆增强(Beta)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious_hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "启用图像转述(需要模型支持)",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string"
|
||||
},
|
||||
"active_reply": {
|
||||
"description": "主动回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用主动回复",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。",
|
||||
},
|
||||
"method": {
|
||||
"description": "回复方法",
|
||||
"type": "string",
|
||||
"options": ["possibility_reply"],
|
||||
"hint": "回复方法。possibility_reply 为根据概率回复",
|
||||
},
|
||||
"possibility_reply": {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"obvious_hint": True,
|
||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,prompt是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"put_history_to_prompt": {
|
||||
"description": "将群聊历史记录作为 prompt",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复,建议启用,模型能够更好地完成回复任务。",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"misc_config_group": {
|
||||
@@ -735,8 +590,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "机器人唤醒前缀",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
|
||||
},
|
||||
"t2i": {
|
||||
"description": "文本转图像",
|
||||
@@ -746,7 +600,7 @@ CONFIG_METADATA_2 = {
|
||||
"admins_id": {
|
||||
"description": "管理员 ID",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"items": {"type": "int"},
|
||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
|
||||
},
|
||||
"http_proxy": {
|
||||
|
||||
@@ -7,6 +7,7 @@ from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
|
||||
@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
|
||||
proxy: T.Optional[bool] = True
|
||||
timeout: T.Optional[int] = 0
|
||||
# 额外
|
||||
path: T.Optional[str]
|
||||
path: T.Optional[str] # 用这个
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
for k in _.keys():
|
||||
@@ -306,7 +306,7 @@ class Image(BaseMessageComponent):
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
id: T.Union[str, int]
|
||||
id: int
|
||||
text: T.Optional[str] = ""
|
||||
qq: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0
|
||||
|
||||
@@ -13,10 +13,12 @@ class MessageChain():
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
'''
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
def message(self, message: str):
|
||||
'''添加一条文本消息到消息链 `chain` 中。
|
||||
@@ -75,6 +77,16 @@ class MessageChain():
|
||||
'''
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def is_split(self, is_split: bool):
|
||||
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
Note:
|
||||
具体的效果以各适配器实现为准。
|
||||
|
||||
'''
|
||||
self.is_split_ = is_split
|
||||
return self
|
||||
|
||||
class EventResultType(enum.Enum):
|
||||
'''用于描述事件处理的结果类型。
|
||||
@@ -101,6 +113,7 @@ class MessageEventResult(MessageChain):
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
`result_type` (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
|
||||
@@ -126,7 +139,7 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
|
||||
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
|
||||
'''设置事件处理的结果类型。
|
||||
|
||||
Args:
|
||||
@@ -135,15 +148,5 @@ class MessageEventResult(MessageChain):
|
||||
self.result_content_type = typ
|
||||
return self
|
||||
|
||||
def is_llm_result(self) -> bool:
|
||||
'''是否为 LLM 结果。
|
||||
'''
|
||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
def get_plain_text(self) -> str:
|
||||
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
|
||||
'''
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -21,10 +21,6 @@ class DifyRequestSubStage(Stage):
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
return
|
||||
|
||||
if provider.meta().type != "dify":
|
||||
return
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class LLMRequestSubStage(Stage):
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
if not req.prompt:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件。
|
||||
@@ -68,15 +68,6 @@ class LLMRequestSubStage(Stage):
|
||||
if _nested:
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
|
||||
# 执行 LLM 响应后的事件。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
|
||||
@@ -39,11 +39,8 @@ class StarRequestSubStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
@@ -37,11 +37,7 @@ class ProcessStage(Stage):
|
||||
# Handler 的 LLM 请求
|
||||
logger.debug(f"llm request -> {resp.prompt}")
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
@@ -53,11 +49,6 @@ class ProcessStage(Stage):
|
||||
if not event._has_send_oper and event.is_at_or_wake_command:
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
match provider.meta().type:
|
||||
case "dify":
|
||||
async for _ in self.dify_request_sub_stage.process(event):
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import random
|
||||
import asyncio
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@@ -12,19 +9,6 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
class RespondStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
|
||||
# 分段回复
|
||||
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
|
||||
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
|
||||
interval_str_ls = interval_str.replace(" ", "").split(",")
|
||||
try:
|
||||
self.interval = [float(t) for t in interval_str_ls]
|
||||
except BaseException as e:
|
||||
logger.error(f'解析分段回复的间隔时间失败。{e}')
|
||||
self.interval = [1.5, 3.5]
|
||||
logger.info(f"分段回复间隔时间:{self.interval}")
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
@@ -32,16 +16,7 @@ class RespondStage(Stage):
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event._pre_send()
|
||||
|
||||
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
await event.send(MessageChain([comp]))
|
||||
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
|
||||
else:
|
||||
await event.send(result)
|
||||
await event._post_send()
|
||||
await event.send(result)
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import time
|
||||
import re
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@@ -18,12 +16,7 @@ class ResultDecorateStage:
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
||||
|
||||
# 分段回复
|
||||
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
|
||||
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
@@ -38,53 +31,10 @@ class ResultDecorateStage:
|
||||
if len(result.chain) > 0:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
comp.text = self.reply_prefix + comp.text
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
if self.enable_segmented_reply:
|
||||
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
for seg in split_response:
|
||||
new_chain.append(Plain(seg))
|
||||
else:
|
||||
# 非 Plain 类型的消息段不分段
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
if self.use_tts and result.is_llm_result():
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(Record(file=audio_path, url=audio_path))
|
||||
else:
|
||||
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
result.chain.insert(0, Plain(self.reply_prefix))
|
||||
|
||||
# 文本转图片
|
||||
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
|
||||
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
@@ -104,7 +54,7 @@ class ResultDecorateStage:
|
||||
result.chain = [Image.fromURL(url)]
|
||||
|
||||
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
|
||||
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
|
||||
result.chain.insert(0, At(qq=event.get_sender_id()))
|
||||
|
||||
if self.reply_with_quote:
|
||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||
@@ -18,11 +18,6 @@ class WhitelistCheckStage(Stage):
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if not self.enable_whitelist_check:
|
||||
# 白名单检查未启用
|
||||
return
|
||||
|
||||
if len(self.whitelist) == 0:
|
||||
# 白名单为空,不检查
|
||||
return
|
||||
|
||||
if event.get_platform_name() == 'webchat':
|
||||
|
||||
@@ -179,15 +179,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
'''调度器会在执行 send() 前调用该方法'''
|
||||
pass
|
||||
|
||||
async def _post_send(self):
|
||||
'''调度器会在执行 send() 后调用该方法'''
|
||||
pass
|
||||
|
||||
|
||||
def set_result(self, result: Union[MessageEventResult, str]):
|
||||
'''设置消息事件的结果。
|
||||
|
||||
@@ -296,7 +287,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
def request_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager = None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
contexts: List = None,
|
||||
@@ -312,13 +302,11 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
|
||||
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||
'''
|
||||
return ProviderRequest(
|
||||
prompt = prompt,
|
||||
session_id = session_id,
|
||||
image_urls = image_urls,
|
||||
func_tool = func_tool_manager,
|
||||
contexts = contexts,
|
||||
system_prompt = system_prompt
|
||||
)
|
||||
@@ -24,12 +24,11 @@ class PlatformManager():
|
||||
case "qq_official":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
try:
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
except BaseException:
|
||||
logger.warning("当前 astrbot 已不维护 vchat 的接入,如有需要请 pip 安装 vchat 然后重启")
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
case "mispeaker":
|
||||
from .sources.mispeaker.mispeaker_adapter import MiSpeakerPlatformAdapter # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import asyncio
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
|
||||
@@ -20,18 +20,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d['type'] = 'text'
|
||||
if isinstance(segment, (Image, Record)):
|
||||
if isinstance(segment, Image):
|
||||
# convert to base64
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
bs64_data = file_to_base64(segment.file[8:])
|
||||
image_base64 = file_to_base64(segment.file[8:])
|
||||
image_file_path = segment.file[8:]
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(segment.file)
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
else:
|
||||
bs64_data = file_to_base64(segment.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
d['data'] = {
|
||||
'file': bs64_data,
|
||||
'file': image_base64,
|
||||
}
|
||||
ret.append(d)
|
||||
return ret
|
||||
@@ -40,5 +38,11 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
if message.is_split_: # 分条发送
|
||||
for m in ret:
|
||||
await self.bot.send(self.message_obj.raw_message, [m])
|
||||
await asyncio.sleep(random.uniform(0.75, 2.5))
|
||||
else:
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
await super().send(message)
|
||||
@@ -2,14 +2,10 @@ import threading
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import quart
|
||||
import base64
|
||||
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api.message_components import Plain, Image, At
|
||||
from astrbot.api import logger, sp
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
|
||||
class SimpleGewechatClient():
|
||||
'''针对 Gewechat 的简单实现。
|
||||
@@ -21,15 +17,9 @@ class SimpleGewechatClient():
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith('/'):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
|
||||
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
@@ -37,19 +27,15 @@ class SimpleGewechatClient():
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
self.callback_url = None
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
||||
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
|
||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
async def get_token_id(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -66,87 +52,57 @@ class SimpleGewechatClient():
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
d = data['Data']
|
||||
|
||||
from_user_name = d['FromUserName']['string'] # 消息来源
|
||||
d['to_wxid'] = from_user_name # 用于发信息
|
||||
msg_type = d['MsgType']
|
||||
|
||||
abm.message_id = str(d.get('MsgId'))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data['Wxid'] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d['Content']['string'] # 消息内容
|
||||
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(':\n')
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
if '\u2005' in content:
|
||||
# at
|
||||
content = content.split('\u2005')[1]
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d['MsgSource']
|
||||
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
|
||||
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
|
||||
.replace('在群聊中@了你', '') \
|
||||
.replace('在群聊中发了一段语音', '') # 真实昵称
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
# 不同消息类型
|
||||
match d['MsgType']:
|
||||
match msg_type:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
from_user_name = d['FromUserName']['string'] # 消息来源
|
||||
d['to_wxid'] = from_user_name # 用于发信息
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d['Content']['string'] # 消息内容
|
||||
user_real_name = d['PushContent'].split(' : ')[0] # 真实昵称
|
||||
user_real_name = user_real_name.replace('在群聊中@了你', '') # trick
|
||||
abm.self_id = data['Wxid'] # 机器人的 wxid
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(':\n')
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
if '\u2005' in content:
|
||||
# at
|
||||
content = content.split('\u2005')[1]
|
||||
|
||||
abm.group_id = from_user_name
|
||||
|
||||
# at
|
||||
msg_source = d['MsgSource']
|
||||
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
|
||||
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
|
||||
at_me = True
|
||||
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
abm.session_id = from_user_name
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.message = [Plain(content)]
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
abm.message_id = str(d['MsgId'])
|
||||
abm.raw_message = d
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid,
|
||||
content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
# data = await self.multimedia_downloader.download_voice(
|
||||
# self.appid,
|
||||
# content,
|
||||
# abm.message_id
|
||||
# )
|
||||
# print(data)
|
||||
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
|
||||
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
return abm
|
||||
case _:
|
||||
logger.error(f"未实现的消息类型: {d['MsgType']}")
|
||||
return
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
logger.error(f"未实现的消息类型: {msg_type}")
|
||||
|
||||
async def callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
@@ -154,43 +110,40 @@ class SimpleGewechatClient():
|
||||
if data.get('testMsg', None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = None
|
||||
try:
|
||||
abm = await self._convert(data)
|
||||
except BaseException as e:
|
||||
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。")
|
||||
|
||||
abm = await self._convert(data)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"token": self.token,
|
||||
"callbackUrl": self.callback_url
|
||||
"callbackUrl": callback_url
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
|
||||
logger.info(f"将在 {callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
|
||||
|
||||
async def start_polling(self):
|
||||
|
||||
# 设置回调
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
|
||||
|
||||
await self.server.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
@@ -233,8 +186,6 @@ class SimpleGewechatClient():
|
||||
async def login(self):
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
|
||||
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
@@ -294,7 +245,7 @@ class SimpleGewechatClient():
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||
sp.put(f"gewechat-appid-{nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
@@ -312,39 +263,4 @@ class SimpleGewechatClient():
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送语音结果: {json_blob}")
|
||||
logger.info(f"发送消息结果: {json_blob}")
|
||||
@@ -1,51 +0,0 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
class GeweDownloader():
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-GEWE-TOKEN": token
|
||||
}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {
|
||||
"appId": appid,
|
||||
"xml": xml,
|
||||
"msgId": msg_id
|
||||
}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
'''返回一个可下载的 URL'''
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {
|
||||
"appId": appid,
|
||||
"xml": xml,
|
||||
"type": choice
|
||||
}
|
||||
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
|
||||
json_blob = json.loads(data)
|
||||
if 'fileUrl' in json_blob['data']:
|
||||
return self.download_base_url + json_blob['data']['fileUrl']
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
@@ -1,24 +1,12 @@
|
||||
import wave
|
||||
import uuid
|
||||
import os
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
import random
|
||||
import asyncio
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, 'rb') as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,57 +34,5 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
img_url = comp.file
|
||||
img_path = ""
|
||||
if img_url.startswith("file:///"):
|
||||
img_path = img_url[8:]
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
img_path = await download_image_by_url(comp.file)
|
||||
else:
|
||||
img_path = img_url
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
temp_directory = os.path.abspath('data/temp')
|
||||
img_path = os.path.abspath(img_path)
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await self.client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = ""
|
||||
|
||||
if record_url.startswith("file:///"):
|
||||
record_path = record_url[8:]
|
||||
elif record_url.startswith("http"):
|
||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
||||
else:
|
||||
record_path = record_url
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
|
||||
print(f"duration: {duration}, {silk_path}")
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
# temp_directory = os.path.abspath('data/temp')
|
||||
# record_path = os.path.abspath(record_path)
|
||||
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
|
||||
# with open(record_path, "rb") as f:
|
||||
# record_path = f"data/temp/{uuid.uuid4()}.wav"
|
||||
# with open(record_path, "wb") as f2:
|
||||
# f2.write(f.read())
|
||||
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
||||
await self.client.post_voice(to_wxid, record_url, duration*1000)
|
||||
|
||||
await super().send(message)
|
||||
137
astrbot/core/platform/sources/mispeaker/client.py
Normal file
137
astrbot/core/platform/sources/mispeaker/client.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
import traceback
|
||||
from .miservice import MiAccount, MiNAService, MiIOService, miio_command, miio_command_help
|
||||
from astrbot.core import logger
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At
|
||||
|
||||
class SimpleMiSpeakerClient():
|
||||
'''
|
||||
@author: Soulter
|
||||
@references: https://github.com/yihong0618/xiaogpt/blob/main/xiaogpt/xiaogpt.py
|
||||
'''
|
||||
def __init__(self, config: dict):
|
||||
self.username = config['username']
|
||||
self.password = config['password']
|
||||
self.did = config['did']
|
||||
self.store = os.path.join("data", '.mi.token')
|
||||
self.interval = float(config.get('interval', 1))
|
||||
|
||||
self.conv_query_cookies = {
|
||||
'userId': '',
|
||||
'deviceId': '',
|
||||
'serviceToken': ''
|
||||
}
|
||||
|
||||
self.MI_CONVERSATION_URL = "https://userprofile.mina.mi.com/device_profile/v2/conversation?source=dialogu&hardware={hardware}×tamp={timestamp}&limit=1"
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
self.activate_word = config.get('activate_word', '测试')
|
||||
self.deactivate_word = config.get('deactivate_word', '停止')
|
||||
|
||||
self.entered = False
|
||||
|
||||
async def initialize(self):
|
||||
account = MiAccount(self.session, self.username, self.password, self.store)
|
||||
self.miio_service = MiIOService(account) # 小米设备服务
|
||||
self.mina_service = MiNAService(account) # 小爱音箱服务
|
||||
|
||||
device = await self.get_mina_device()
|
||||
|
||||
self.deviceID = device['deviceID']
|
||||
self.hardware = device['hardware']
|
||||
|
||||
with open(self.store, 'r') as f:
|
||||
data = json.load(f)
|
||||
self.userId = data['userId']
|
||||
self.serviceToken = data['micoapi'][1]
|
||||
self.conv_query_cookies['userId'] = self.userId
|
||||
self.conv_query_cookies['deviceId'] = self.deviceID
|
||||
self.conv_query_cookies['serviceToken'] = self.serviceToken
|
||||
|
||||
logger.info(f"MiSpeakerClient initialized. Conv cookies: {self.conv_query_cookies}. Hardware: {self.hardware}")
|
||||
|
||||
async def get_mina_device(self) -> dict:
|
||||
devices = await self.mina_service.device_list()
|
||||
for device in devices:
|
||||
if device['miotDID'] == self.did:
|
||||
logger.info(f"找到设备 {device['alias']}({device['name']}) 了!")
|
||||
return device
|
||||
|
||||
async def get_conv(self) -> str:
|
||||
# 时区请确保为北京时间
|
||||
async with aiohttp.ClientSession() as session:
|
||||
session.cookie_jar.update_cookies(self.conv_query_cookies)
|
||||
query_ts = int(time.time())*1000
|
||||
logger.debug(f"Querying conversation at {query_ts}")
|
||||
async with session.get(self.MI_CONVERSATION_URL.format(hardware=self.hardware, timestamp=str(query_ts))) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob['code'] == 0:
|
||||
data = json.loads(json_blob['data'])
|
||||
records = data.get('records', None)
|
||||
for record in records:
|
||||
if record['time'] >= query_ts - self.interval*1000:
|
||||
return record['query']
|
||||
else:
|
||||
logger.error(f"Failed to get conversation: {json_blob}")
|
||||
|
||||
return None
|
||||
|
||||
async def start_pooling(self):
|
||||
while True:
|
||||
await asyncio.sleep(self.interval)
|
||||
try:
|
||||
query = await self.get_conv()
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# is wake
|
||||
if query == self.activate_word:
|
||||
self.entered = True
|
||||
await self.stop_playing()
|
||||
await self.send("我来啦!")
|
||||
continue
|
||||
elif query == self.deactivate_word:
|
||||
self.entered = False
|
||||
await self.stop_playing()
|
||||
await self.send("再见,欢迎给个 Star。")
|
||||
continue
|
||||
if not self.entered:
|
||||
continue
|
||||
|
||||
await self.send("")
|
||||
abm = await self._convert(query)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
logger.error(e)
|
||||
|
||||
async def _convert(self, query: str):
|
||||
abm = AstrBotMessage()
|
||||
abm.message = [Plain(query)]
|
||||
abm.message_id = str(int(time.time()))
|
||||
abm.message_str = query
|
||||
abm.raw_message = query
|
||||
abm.session_id = f"{self.hardware}_{self.did}_{self.username}"
|
||||
abm.sender = MessageMember(self.username, "主人")
|
||||
abm.self_id = f"{self.hardware}_{self.did}"
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return abm
|
||||
|
||||
async def send(self, message: str):
|
||||
text = f'5 {message}'
|
||||
await miio_command(self.miio_service, self.did, text, 'astrbot')
|
||||
|
||||
async def stop_playing(self):
|
||||
text = f'3-2'
|
||||
await miio_command(self.miio_service, self.did, text, 'astrbot')
|
||||
21
astrbot/core/platform/sources/mispeaker/miservice/LICENSE
Normal file
21
astrbot/core/platform/sources/mispeaker/miservice/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021-2022 Yonsm
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
5
astrbot/core/platform/sources/mispeaker/miservice/__init__.py
Executable file
5
astrbot/core/platform/sources/mispeaker/miservice/__init__.py
Executable file
@@ -0,0 +1,5 @@
|
||||
from .miaccount import MiAccount, MiTokenStore
|
||||
from .minaservice import MiNAService
|
||||
from .miioservice import MiIOService
|
||||
from .miiocommand import miio_command, miio_command_help
|
||||
|
||||
135
astrbot/core/platform/sources/mispeaker/miservice/miaccount.py
Normal file
135
astrbot/core/platform/sources/mispeaker/miservice/miaccount.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from urllib import parse
|
||||
from aiohttp import ClientSession
|
||||
from aiofiles import open as async_open
|
||||
|
||||
_LOGGER = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def get_random(length):
|
||||
return ''.join(random.sample(string.ascii_letters + string.digits, length))
|
||||
|
||||
|
||||
class MiTokenStore:
|
||||
|
||||
def __init__(self, token_path):
|
||||
self.token_path = token_path
|
||||
|
||||
async def load_token(self):
|
||||
if os.path.isfile(self.token_path):
|
||||
try:
|
||||
async with async_open(self.token_path) as f:
|
||||
return json.loads(await f.read())
|
||||
except Exception as e:
|
||||
_LOGGER.exception("Exception on load token from %s: %s", self.token_path, e)
|
||||
return None
|
||||
|
||||
async def save_token(self, token=None):
|
||||
if token:
|
||||
try:
|
||||
async with async_open(self.token_path, 'w') as f:
|
||||
await f.write(json.dumps(token, indent=2))
|
||||
except Exception as e:
|
||||
_LOGGER.exception("Exception on save token to %s: %s", self.token_path, e)
|
||||
elif os.path.isfile(self.token_path):
|
||||
os.remove(self.token_path)
|
||||
|
||||
|
||||
class MiAccount:
|
||||
|
||||
def __init__(self, session: ClientSession, username, password, token_store='.mi.token'):
|
||||
self.session = session
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.token_store = MiTokenStore(token_store) if isinstance(token_store, str) else token_store
|
||||
self.token = None
|
||||
|
||||
async def login(self, sid):
|
||||
if not self.token:
|
||||
self.token = {'deviceId': get_random(16).upper()}
|
||||
try:
|
||||
resp = await self._serviceLogin(f'serviceLogin?sid={sid}&_json=true')
|
||||
if resp['code'] != 0:
|
||||
data = {
|
||||
'_json': 'true',
|
||||
'qs': resp['qs'],
|
||||
'sid': resp['sid'],
|
||||
'_sign': resp['_sign'],
|
||||
'callback': resp['callback'],
|
||||
'user': self.username,
|
||||
'hash': hashlib.md5(self.password.encode()).hexdigest().upper()
|
||||
}
|
||||
resp = await self._serviceLogin('serviceLoginAuth2', data)
|
||||
if resp['code'] != 0:
|
||||
raise Exception(resp)
|
||||
|
||||
self.token['userId'] = resp['userId']
|
||||
self.token['passToken'] = resp['passToken']
|
||||
|
||||
serviceToken = await self._securityTokenService(resp['location'], resp['nonce'], resp['ssecurity'])
|
||||
self.token[sid] = (resp['ssecurity'], serviceToken)
|
||||
if self.token_store:
|
||||
await self.token_store.save_token(self.token)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.token = None
|
||||
if self.token_store:
|
||||
await self.token_store.save_token()
|
||||
_LOGGER.exception("Exception on login %s: %s", self.username, e)
|
||||
return False
|
||||
|
||||
async def _serviceLogin(self, uri, data=None):
|
||||
headers = {'User-Agent': 'APP/com.xiaomi.mihome APPV/6.0.103 iosPassportSDK/3.9.0 iOS/14.4 miHSTS'}
|
||||
cookies = {'sdkVersion': '3.9', 'deviceId': self.token['deviceId']}
|
||||
if 'passToken' in self.token:
|
||||
cookies['userId'] = self.token['userId']
|
||||
cookies['passToken'] = self.token['passToken']
|
||||
url = 'https://account.xiaomi.com/pass/' + uri
|
||||
async with self.session.request('GET' if data is None else 'POST', url, data=data, cookies=cookies, headers=headers) as r:
|
||||
raw = await r.read()
|
||||
resp = json.loads(raw[11:])
|
||||
_LOGGER.debug("%s: %s", uri, resp)
|
||||
return resp
|
||||
|
||||
async def _securityTokenService(self, location, nonce, ssecurity):
|
||||
nsec = 'nonce=' + str(nonce) + '&' + ssecurity
|
||||
clientSign = base64.b64encode(hashlib.sha1(nsec.encode()).digest()).decode()
|
||||
async with self.session.get(location + '&clientSign=' + parse.quote(clientSign)) as r:
|
||||
serviceToken = r.cookies['serviceToken'].value
|
||||
if not serviceToken:
|
||||
raise Exception(await r.text())
|
||||
return serviceToken
|
||||
|
||||
async def mi_request(self, sid, url, data, headers, relogin=True):
|
||||
if self.token is None and self.token_store is not None:
|
||||
self.token = await self.token_store.load_token()
|
||||
if (self.token and sid in self.token) or await self.login(sid): # Ensure login
|
||||
cookies = {'userId': self.token['userId'], 'serviceToken': self.token[sid][1]}
|
||||
content = data(self.token, cookies) if callable(data) else data
|
||||
method = 'GET' if data is None else 'POST'
|
||||
_LOGGER.debug("%s %s", url, content)
|
||||
async with self.session.request(method, url, data=content, cookies=cookies, headers=headers) as r:
|
||||
status = r.status
|
||||
if status == 200:
|
||||
resp = await r.json(content_type=None)
|
||||
code = resp['code']
|
||||
if code == 0:
|
||||
return resp
|
||||
if 'auth' in resp.get('message', '').lower():
|
||||
status = 401
|
||||
else:
|
||||
resp = await r.text()
|
||||
if status == 401 and relogin:
|
||||
_LOGGER.warn("Auth error on request %s %s, relogin...", url, resp)
|
||||
self.token = None # Auth error, reset login
|
||||
return await self.mi_request(sid, url, data, headers, False)
|
||||
else:
|
||||
resp = "Login failed"
|
||||
raise Exception(f"Error {url}: {resp}")
|
||||
104
astrbot/core/platform/sources/mispeaker/miservice/miiocommand.py
Executable file
104
astrbot/core/platform/sources/mispeaker/miservice/miiocommand.py
Executable file
@@ -0,0 +1,104 @@
|
||||
|
||||
import json
|
||||
from .miioservice import MiIOService
|
||||
|
||||
|
||||
def twins_split(string, sep, default=None):
|
||||
pos = string.find(sep)
|
||||
return (string, default) if pos == -1 else (string[0:pos], string[pos+1:])
|
||||
|
||||
|
||||
def string_to_value(string):
|
||||
if string[0] in '"\'#':
|
||||
return string[1:-1] if string[-1] in '"\'#' else string[1:]
|
||||
elif string == 'null':
|
||||
return None
|
||||
elif string == 'false':
|
||||
return False
|
||||
elif string == 'true':
|
||||
return True
|
||||
elif string.isdigit():
|
||||
return int(string)
|
||||
try:
|
||||
return float(string)
|
||||
except:
|
||||
return string
|
||||
|
||||
def miio_command_help(did=None, prefix='?'):
|
||||
quote = '' if prefix == '?' else "'"
|
||||
return f'\
|
||||
Get Props: {prefix}<siid[-piid]>[,...]\n\
|
||||
{prefix}1,1-2,1-3,1-4,2-1,2-2,3\n\
|
||||
Set Props: {prefix}<siid[-piid]=[#]value>[,...]\n\
|
||||
{prefix}2=60,2-1=#60,2-2=false,2-3="null",3=test\n\
|
||||
Do Action: {prefix}<siid[-piid]> <arg1|[]> [...] \n\
|
||||
{prefix}2 []\n\
|
||||
{prefix}5 Hello\n\
|
||||
{prefix}5-4 Hello 1\n\n\
|
||||
Call MIoT: {prefix}<cmd=prop/get|/prop/set|action> <params>\n\
|
||||
{prefix}action {quote}{{"did":"{did or "267090026"}","siid":5,"aiid":1,"in":["Hello"]}}{quote}\n\n\
|
||||
Call MiIO: {prefix}/<uri> <data>\n\
|
||||
{prefix}/home/device_list {quote}{{"getVirtualModel":false,"getHuamiDevices":1}}{quote}\n\n\
|
||||
Devs List: {prefix}list [name=full|name_keyword] [getVirtualModel=false|true] [getHuamiDevices=0|1]\n\
|
||||
{prefix}list Light true 0\n\n\
|
||||
MIoT Spec: {prefix}spec [model_keyword|type_urn] [format=text|python|json]\n\
|
||||
{prefix}spec\n\
|
||||
{prefix}spec speaker\n\
|
||||
{prefix}spec xiaomi.wifispeaker.lx04\n\
|
||||
{prefix}spec urn:miot-spec-v2:device:speaker:0000A015:xiaomi-lx04:1\n\n\
|
||||
MIoT Decode: {prefix}decode <ssecurity> <nonce> <data> [gzip]\n\
|
||||
'
|
||||
|
||||
|
||||
async def miio_command(service: MiIOService, did, text, prefix='?'):
|
||||
cmd, arg = twins_split(text, ' ')
|
||||
|
||||
if cmd.startswith('/'):
|
||||
return await service.miio_request(cmd, arg)
|
||||
|
||||
if cmd.startswith('prop') or cmd == 'action':
|
||||
return await service.miot_request(cmd, json.loads(arg) if arg else None)
|
||||
|
||||
argv = arg.split(' ') if arg else []
|
||||
argc = len(argv)
|
||||
if cmd == 'list':
|
||||
return await service.device_list(argc > 0 and argv[0], argc > 1 and string_to_value(argv[1]), argc > 2 and argv[2])
|
||||
|
||||
if cmd == 'spec':
|
||||
return await service.miot_spec(argc > 0 and argv[0], argc > 1 and argv[1])
|
||||
|
||||
if cmd == 'decode':
|
||||
return MiIOService.miot_decode(argv[0], argv[1], argv[2], argc > 3 and argv[3] == 'gzip')
|
||||
|
||||
if not did or not cmd or cmd == '?' or cmd == '?' or cmd == 'help' or cmd == '-h' or cmd == '--help':
|
||||
return miio_command_help(did, prefix)
|
||||
|
||||
if not did.isdigit():
|
||||
devices = await service.device_list(did)
|
||||
if not devices:
|
||||
return "Device not found: " + did
|
||||
did = devices[0]['did']
|
||||
|
||||
props = []
|
||||
setp = True
|
||||
miot = True
|
||||
for item in cmd.split(','):
|
||||
key, value = twins_split(item, '=')
|
||||
siid, iid = twins_split(key, '-', '1')
|
||||
if siid.isdigit() and iid.isdigit():
|
||||
prop = [int(siid), int(iid)]
|
||||
else:
|
||||
prop = [key]
|
||||
miot = False
|
||||
if value is None:
|
||||
setp = False
|
||||
elif setp:
|
||||
prop.append(string_to_value(value))
|
||||
props.append(prop)
|
||||
|
||||
if miot and argc > 0:
|
||||
args = [] if arg == '[]' else [string_to_value(a) for a in argv]
|
||||
return await service.miot_action(did, props[0], args)
|
||||
|
||||
do_props = ((service.home_get_props, service.miot_get_props), (service.home_set_props, service.miot_set_props))[setp][miot]
|
||||
return await do_props(did, props if miot or setp else [p[0] for p in props])
|
||||
197
astrbot/core/platform/sources/mispeaker/miservice/miioservice.py
Executable file
197
astrbot/core/platform/sources/mispeaker/miservice/miioservice.py
Executable file
@@ -0,0 +1,197 @@
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
|
||||
# REGIONS = ['cn', 'de', 'i2', 'ru', 'sg', 'us']
|
||||
|
||||
|
||||
class MiIOService:
|
||||
|
||||
def __init__(self, account=None, region=None):
|
||||
self.account = account
|
||||
self.server = 'https://' + ('' if region is None or region == 'cn' else region + '.') + 'api.io.mi.com/app'
|
||||
|
||||
async def miio_request(self, uri, data):
|
||||
def prepare_data(token, cookies):
|
||||
cookies['PassportDeviceId'] = token['deviceId']
|
||||
return MiIOService.sign_data(uri, data, token['xiaomiio'][0])
|
||||
headers = {'User-Agent': 'iOS-14.4-6.0.103-iPhone12,3--D7744744F7AF32F0544445285880DD63E47D9BE9-8816080-84A3F44E137B71AE-iPhone', 'x-xiaomi-protocal-flag-cli': 'PROTOCAL-HTTP2'}
|
||||
resp = await self.account.mi_request('xiaomiio', self.server + uri, prepare_data, headers)
|
||||
if 'result' not in resp:
|
||||
raise Exception(f"Error {uri}: {resp}")
|
||||
return resp['result']
|
||||
|
||||
async def home_request(self, did, method, params):
|
||||
return await self.miio_request('/home/rpc/' + did, {'id': 1, 'method': method, "accessKey": "IOS00026747c5acafc2", 'params': params})
|
||||
|
||||
async def home_get_props(self, did, props):
|
||||
return await self.home_request(did, 'get_prop', props)
|
||||
|
||||
async def home_set_props(self, did, props):
|
||||
return [await self.home_set_prop(did, i[0], i[1]) for i in props]
|
||||
|
||||
async def home_get_prop(self, did, prop):
|
||||
return (await self.home_get_props(did, [prop]))[0]
|
||||
|
||||
async def home_set_prop(self, did, prop, value):
|
||||
result = (await self.home_request(did, 'set_' + prop, value if isinstance(value, list) else [value]))[0]
|
||||
return 0 if result == 'ok' else result
|
||||
|
||||
async def miot_request(self, cmd, params):
|
||||
return await self.miio_request('/miotspec/' + cmd, {'params': params})
|
||||
|
||||
async def miot_get_props(self, did, iids):
|
||||
params = [{'did': did, 'siid': i[0], 'piid': i[1]} for i in iids]
|
||||
result = await self.miot_request('prop/get', params)
|
||||
return [it.get('value') if it.get('code') == 0 else None for it in result]
|
||||
|
||||
async def miot_set_props(self, did, props):
|
||||
params = [{'did': did, 'siid': i[0], 'piid': i[1], 'value': i[2]} for i in props]
|
||||
result = await self.miot_request('prop/set', params)
|
||||
return [it.get('code', -1) for it in result]
|
||||
|
||||
async def miot_get_prop(self, did, iid):
|
||||
return (await self.miot_get_props(did, [iid]))[0]
|
||||
|
||||
async def miot_set_prop(self, did, iid, value):
|
||||
return (await self.miot_set_props(did, [(iid[0], iid[1], value)]))[0]
|
||||
|
||||
async def miot_action(self, did, iid, args=[]):
|
||||
result = await self.miot_request('action', {'did': did, 'siid': iid[0], 'aiid': iid[1], 'in': args})
|
||||
return result.get('code', -1)
|
||||
|
||||
async def device_list(self, name=None, getVirtualModel=False, getHuamiDevices=0):
|
||||
result = await self.miio_request('/home/device_list', {'getVirtualModel': bool(getVirtualModel), 'getHuamiDevices': int(getHuamiDevices)})
|
||||
result = result['list']
|
||||
return result if name == 'full' else [{'name': i['name'], 'model': i['model'], 'did': i['did'], 'token': i['token']} for i in result if not name or name in i['name']]
|
||||
|
||||
async def miot_spec(self, type=None, format=None):
|
||||
if not type or not type.startswith('urn'):
|
||||
def get_spec(all):
|
||||
if not type:
|
||||
return all
|
||||
ret = {}
|
||||
for m, t in all.items():
|
||||
if type == m:
|
||||
return {m: t}
|
||||
elif type in m:
|
||||
ret[m] = t
|
||||
return ret
|
||||
import tempfile
|
||||
path = os.path.join(tempfile.gettempdir(), 'miservice_miot_specs.json')
|
||||
try:
|
||||
with open(path) as f:
|
||||
result = get_spec(json.load(f))
|
||||
except:
|
||||
result = None
|
||||
if not result:
|
||||
async with self.account.session.get('http://miot-spec.org/miot-spec-v2/instances?status=all') as r:
|
||||
all = {i['model']: i['type'] for i in (await r.json())['instances']}
|
||||
with open(path, 'w') as f:
|
||||
json.dump(all, f)
|
||||
result = get_spec(all)
|
||||
if len(result) != 1:
|
||||
return result
|
||||
type = list(result.values())[0]
|
||||
|
||||
url = 'http://miot-spec.org/miot-spec-v2/instance?type=' + type
|
||||
async with self.account.session.get(url) as r:
|
||||
result = await r.json()
|
||||
|
||||
def parse_desc(node):
|
||||
desc = node['description']
|
||||
# pos = desc.find(' ')
|
||||
# if pos != -1:
|
||||
# return (desc[:pos], ' # ' + desc[pos + 2:])
|
||||
name = ''
|
||||
for i in range(len(desc)):
|
||||
d = desc[i]
|
||||
if d in '-—{「[【((<《':
|
||||
return (name, ' # ' + desc[i:])
|
||||
name += '_' if d == ' ' else d
|
||||
return (name, '')
|
||||
|
||||
def make_line(siid, iid, desc, comment, readable=False):
|
||||
value = f"({siid}, {iid})" if format == 'python' else iid
|
||||
return f" {'' if readable else '_'}{desc} = {value}{comment}\n"
|
||||
|
||||
if format != 'json':
|
||||
STR_HEAD, STR_SRV, STR_VALUE = ('from enum import Enum\n\n', '\nclass {}(tuple, Enum):\n', '\nclass {}(int, Enum):\n') if format == 'python' else ('', '{} = {}\n', '{}\n')
|
||||
text = '# Generated by https://github.com/Yonsm/MiService\n# ' + url + '\n\n' + STR_HEAD
|
||||
svcs = []
|
||||
vals = []
|
||||
|
||||
for s in result['services']:
|
||||
siid = s['iid']
|
||||
svc = s['description'].replace(' ', '_')
|
||||
svcs.append(svc)
|
||||
text += STR_SRV.format(svc, siid)
|
||||
for p in s.get('properties', []):
|
||||
name, comment = parse_desc(p)
|
||||
access = p['access']
|
||||
|
||||
comment += ''.join([' # ' + k for k, v in [(p['format'], 'string'), (''.join([a[0] for a in access]), 'r')] if k and k != v])
|
||||
text += make_line(siid, p['iid'], name, comment, 'read' in access)
|
||||
if 'value-range' in p:
|
||||
valuer = p['value-range']
|
||||
length = min(3, len(valuer))
|
||||
values = {['MIN', 'MAX', 'STEP'][i]: valuer[i] for i in range(length) if i != 2 or valuer[i] != 1}
|
||||
elif 'value-list' in p:
|
||||
values = {i['description'].replace(' ', '_') if i['description'] else str(i['value']): i['value'] for i in p['value-list']}
|
||||
else:
|
||||
continue
|
||||
vals.append((svc + '_' + name, values))
|
||||
if 'actions' in s:
|
||||
text += '\n'
|
||||
for a in s['actions']:
|
||||
name, comment = parse_desc(a)
|
||||
comment += ''.join([f" # {io}={a[io]}" for io in ['in', 'out'] if a[io]])
|
||||
text += make_line(siid, a['iid'], name, comment)
|
||||
text += '\n'
|
||||
for name, values in vals:
|
||||
text += STR_VALUE.format(name)
|
||||
for k, v in values.items():
|
||||
text += f" {'_' + k if k.isdigit() else k} = {v}\n"
|
||||
text += '\n'
|
||||
if format == 'python':
|
||||
text += '\nALL_SVCS = (' + ', '.join(svcs) + ')\n'
|
||||
result = text
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def miot_decode(ssecurity, nonce, data, gzip=False):
|
||||
from Crypto.Cipher import ARC4
|
||||
r = ARC4.new(base64.b64decode(MiIOService.sign_nonce(ssecurity, nonce)))
|
||||
r.encrypt(bytes(1024))
|
||||
decrypted = r.encrypt(base64.b64decode(data))
|
||||
if gzip:
|
||||
try:
|
||||
from io import BytesIO
|
||||
from gzip import GzipFile
|
||||
compressed = BytesIO()
|
||||
compressed.write(decrypted)
|
||||
compressed.seek(0)
|
||||
decrypted = GzipFile(fileobj=compressed, mode='rb').read()
|
||||
except:
|
||||
pass
|
||||
return json.loads(decrypted.decode())
|
||||
|
||||
@staticmethod
|
||||
def sign_nonce(ssecurity, nonce):
|
||||
m = hashlib.sha256()
|
||||
m.update(base64.b64decode(ssecurity))
|
||||
m.update(base64.b64decode(nonce))
|
||||
return base64.b64encode(m.digest()).decode()
|
||||
|
||||
@staticmethod
|
||||
def sign_data(uri, data, ssecurity):
|
||||
if not isinstance(data, str):
|
||||
data = json.dumps(data)
|
||||
nonce = base64.b64encode(os.urandom(8) + int(time.time() / 60).to_bytes(4, 'big')).decode()
|
||||
snonce = MiIOService.sign_nonce(ssecurity, nonce)
|
||||
msg = '&'.join([uri, snonce, nonce, 'data=' + data])
|
||||
sign = hmac.new(key=base64.b64decode(snonce), msg=msg.encode(), digestmod=hashlib.sha256).digest()
|
||||
return {'_nonce': nonce, 'data': data, 'signature': base64.b64encode(sign).decode()}
|
||||
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
from .miaccount import MiAccount, get_random
|
||||
|
||||
import logging
|
||||
_LOGGER = logging.getLogger(__package__)
|
||||
|
||||
|
||||
class MiNAService:
|
||||
|
||||
def __init__(self, account: MiAccount):
|
||||
self.account = account
|
||||
|
||||
async def mina_request(self, uri, data=None):
|
||||
requestId = 'app_ios_' + get_random(30)
|
||||
if data is not None:
|
||||
data['requestId'] = requestId
|
||||
else:
|
||||
uri += '&requestId=' + requestId
|
||||
headers = {'User-Agent': 'MiHome/6.0.103 (com.xiaomi.mihome; build:6.0.103.1; iOS 14.4.0) Alamofire/6.0.103 MICO/iOSApp/appStore/6.0.103'}
|
||||
return await self.account.mi_request('micoapi', 'https://api2.mina.mi.com' + uri, data, headers)
|
||||
|
||||
async def device_list(self, master=0):
|
||||
result = await self.mina_request('/admin/v2/device_list?master=' + str(master))
|
||||
return result.get('data') if result else None
|
||||
|
||||
async def ubus_request(self, deviceId, method, path, message):
|
||||
message = json.dumps(message)
|
||||
result = await self.mina_request('/remote/ubus', {'deviceId': deviceId, 'message': message, 'method': method, 'path': path})
|
||||
return result and result.get('code') == 0
|
||||
|
||||
async def text_to_speech(self, deviceId, text):
|
||||
return await self.ubus_request(deviceId, 'text_to_speech', 'mibrain', {'text': text})
|
||||
|
||||
async def player_set_volume(self, deviceId, volume):
|
||||
return await self.ubus_request(deviceId, 'player_set_volume', 'mediaplayer', {'volume': volume, 'media': 'app_ios'})
|
||||
|
||||
async def send_message(self, devices, devno, message, volume=None): # -1/0/1...
|
||||
result = False
|
||||
for i in range(0, len(devices)):
|
||||
if devno == -1 or devno != i + 1 or devices[i]['capabilities'].get('yunduantts'):
|
||||
_LOGGER.debug("Send to devno=%d index=%d: %s", devno, i, message or volume)
|
||||
deviceId = devices[i]['deviceID']
|
||||
result = True if volume is None else await self.player_set_volume(deviceId, volume)
|
||||
if result and message:
|
||||
result = await self.text_to_speech(deviceId, message)
|
||||
if not result:
|
||||
_LOGGER.error("Send failed: %s", message or volume)
|
||||
if devno != -1 or not result:
|
||||
break
|
||||
return result
|
||||
63
astrbot/core/platform/sources/mispeaker/mispeaker_adapter.py
Normal file
63
astrbot/core/platform/sources/mispeaker/mispeaker_adapter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from .client import SimpleMiSpeakerClient
|
||||
from .mispeaker_event import MiSpeakerPlatformEvent
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_platform_adapter("mispeaker", "小爱音箱")
|
||||
class MiSpeakerPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
pass
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"mispeaker",
|
||||
"小爱音箱",
|
||||
)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = MiSpeakerPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def run(self):
|
||||
self.client = SimpleMiSpeakerClient(
|
||||
self.config
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
logger.info(f"on_event_received: {abm}")
|
||||
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.initialize()
|
||||
await self.client.start_pooling()
|
||||
30
astrbot/core/platform/sources/mispeaker/mispeaker_event.py
Normal file
30
astrbot/core/platform/sources/mispeaker/mispeaker_event.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import random
|
||||
import asyncio
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from .client import SimpleMiSpeakerClient
|
||||
|
||||
class MiSpeakerPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleMiSpeakerClient
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(message: MessageChain, user_name: str):
|
||||
pass
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.send(comp.text)
|
||||
|
||||
await super().send(message)
|
||||
@@ -5,7 +5,7 @@ import botpy.types.message
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Reply
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
|
||||
@@ -14,33 +14,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
self.send_buffer = None
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not self.send_buffer:
|
||||
self.send_buffer = message
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def _post_send(self):
|
||||
'''QQ 官方 API 仅支持回复一次'''
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
|
||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||
|
||||
ref = None
|
||||
for i in self.send_buffer.chain:
|
||||
if isinstance(i, Reply):
|
||||
try:
|
||||
ref = self.message_obj.raw_message.message_reference
|
||||
ref = botpy.types.message.Reference(
|
||||
message_id=ref.message_id,
|
||||
ignore_get_message_error=False
|
||||
)
|
||||
except BaseException as _:
|
||||
pass
|
||||
break
|
||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
|
||||
|
||||
payload = {
|
||||
'content': plain_text,
|
||||
@@ -49,37 +28,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
if ref:
|
||||
payload['message_reference'] = ref
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
|
||||
payload['media'] = media
|
||||
payload['msg_type'] = 7
|
||||
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
|
||||
case botpy.message.C2CMessage:
|
||||
if ref:
|
||||
payload['message_reference'] = ref
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
|
||||
payload['media'] = media
|
||||
payload['msg_type'] = 7
|
||||
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
|
||||
case botpy.message.Message:
|
||||
if ref:
|
||||
payload['message_reference'] = ref
|
||||
if image_path:
|
||||
payload['file_image'] = image_path
|
||||
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
|
||||
case botpy.message.DirectMessage:
|
||||
if ref:
|
||||
payload['message_reference'] = ref
|
||||
if image_path:
|
||||
payload['file_image'] = image_path
|
||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
await super().send(message)
|
||||
|
||||
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
|
||||
payload = {
|
||||
@@ -111,7 +80,4 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
elif i.file and i.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(i.file)
|
||||
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
|
||||
else:
|
||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
||||
image_file_path = i.file
|
||||
return plain_text, image_base64, image_file_path
|
||||
@@ -32,10 +32,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -2,7 +2,6 @@ import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
@@ -52,7 +51,4 @@ class LLMResponse:
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
'''工具调用名称'''
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
'''工具调用名称'''
|
||||
@@ -108,19 +108,13 @@ class FuncCall:
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
func_declaration = {
|
||||
"name": f.name,
|
||||
"description": f.description
|
||||
}
|
||||
|
||||
# 检查并添加非空的properties参数
|
||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||
if params.get("properties", {}):
|
||||
func_declaration["parameters"] = params
|
||||
|
||||
tools.append(func_declaration)
|
||||
|
||||
tools.append(
|
||||
{
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
)
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import traceback
|
||||
import uuid
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .provider import Provider, STTProvider, Personality
|
||||
from .entites import ProviderType
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -14,11 +13,8 @@ class ProviderManager():
|
||||
self.providers_config: List = config['provider']
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
|
||||
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
|
||||
self.persona_configs: list = config.get('persona', [])
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
@@ -30,7 +26,7 @@ class ProviderManager():
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
|
||||
begin_dialogs = []
|
||||
continue
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append({
|
||||
@@ -42,9 +38,9 @@ class ProviderManager():
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
|
||||
mood_imitation_dialogs = []
|
||||
continue
|
||||
user_turn = True
|
||||
for dialog in mood_imitation_dialogs:
|
||||
for dialog in begin_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
@@ -63,24 +59,16 @@ class ProviderManager():
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
'''加载的 Speech To Text Provider 的实例'''
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
'''加载的 Text To Speech Provider 的实例'''
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
'''当前使用的 Provider 实例'''
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
'''当前使用的 Speech To Text Provider 实例'''
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
'''当前使用的 Text To Speech Provider 实例'''
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
@@ -90,16 +78,12 @@ class ProviderManager():
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
changed = False
|
||||
for provider_cfg in self.providers_config:
|
||||
if not provider_cfg['enable']:
|
||||
continue
|
||||
|
||||
if provider_cfg['id'] in self.loaded_ids:
|
||||
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
|
||||
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}。")
|
||||
provider_cfg['id'] = new_id
|
||||
changed = True
|
||||
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
|
||||
self.loaded_ids[provider_cfg['id']] = True
|
||||
|
||||
try:
|
||||
@@ -119,37 +103,25 @@ class ProviderManager():
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
|
||||
continue
|
||||
|
||||
if changed:
|
||||
try:
|
||||
config.save_config()
|
||||
except Exception as e:
|
||||
logger.warning(f"保存配置文件失败:{e}")
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
async def initialize(self):
|
||||
for provider_config in self.providers_config:
|
||||
if not provider_config['enable']:
|
||||
continue
|
||||
if provider_config['type'] not in provider_cls_map:
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
|
||||
provider_metadata = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
try:
|
||||
@@ -166,18 +138,6 @@ class ProviderManager():
|
||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
@@ -207,18 +167,11 @@ class ProviderManager():
|
||||
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
|
||||
if stt_enabled and not self.curr_stt_provider_inst:
|
||||
if self.provider_stt_settings.get("enable"):
|
||||
if not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
|
||||
if tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -24,32 +24,9 @@ class ProviderMeta():
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
def __init__(self, provider_config: dict) -> None:
|
||||
super().__init__()
|
||||
self.model_name = ""
|
||||
self.provider_config = provider_config
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
class Provider(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
@@ -58,11 +35,14 @@ class Provider(AbstractProvider):
|
||||
db_helper: BaseDatabase = None,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.model_name = ""
|
||||
'''当前使用的模型名称'''
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_config = provider_config
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
@@ -74,10 +54,18 @@ class Provider(AbstractProvider):
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
try:
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
@@ -145,11 +133,17 @@ class Provider(AbstractProvider):
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
|
||||
class STTProvider():
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@@ -157,15 +151,19 @@ class STTProvider(AbstractProvider):
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''获取音频的文本'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_audio(self, text: str) -> str:
|
||||
'''获取文本的音频,返回音频文件路径'''
|
||||
raise NotImplementedError()
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获取当前使用的模型'''
|
||||
return self.provider_config.get("model", "")
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
@@ -260,11 +260,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
|
||||
@@ -118,11 +118,10 @@ class LLMTunerModelLoader(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -72,21 +72,31 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str):
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
弹出最早的一个对话
|
||||
弹出第一条记录
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) < 2:
|
||||
return
|
||||
|
||||
try:
|
||||
self.session_memory[session_id].pop(0)
|
||||
self.session_memory[session_id].pop(0)
|
||||
except IndexError:
|
||||
pass
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
@@ -94,28 +104,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if tool_list:
|
||||
payloads['tools'] = tool_list
|
||||
|
||||
completion = None
|
||||
try:
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
except BaseException as e:
|
||||
# 处理不支持 Function Calling 的模型
|
||||
if 'does not support Function Calling' in str(e) \
|
||||
or 'does not support tools' in str(e) \
|
||||
or 'Function call is not supported' in str(e): # siliconcloud
|
||||
del payloads['tools']
|
||||
logger.debug(f"模型 {self.model_name} 不支持 tools,已自动移除")
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"completion: {completion}")
|
||||
logger.debug(f"completion: {completion.usage}")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
@@ -124,8 +119,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
|
||||
return LLMResponse("assistant", completion_text, raw_completion=completion)
|
||||
return LLMResponse("assistant", completion_text)
|
||||
elif choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
@@ -136,9 +130,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls, raw_completion=completion)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
|
||||
else:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception("Internal Error")
|
||||
|
||||
async def text_chat(
|
||||
@@ -171,29 +164,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
if kwargs.get("persist", True):
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
await self.pop_record(session_id)
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
if kwargs.get("persist", True):
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
@@ -212,11 +201,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import uuid
|
||||
import os
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
|
||||
class ProviderOpenAITTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
self.voice = provider_config.get("openai-tts-voice", "alloy")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name,
|
||||
voice=self.voice,
|
||||
response_format='wav',
|
||||
input=text
|
||||
) as response:
|
||||
with open(path, 'wb') as f:
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
@@ -1,7 +1,3 @@
|
||||
'''
|
||||
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
|
||||
'''
|
||||
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star import star_registry, StarMetadata
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
@@ -54,126 +54,19 @@ class Context:
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
'''根据插件名获取插件的 Metadata'''
|
||||
for star in star_registry:
|
||||
if star.name == star_name:
|
||||
return star
|
||||
|
||||
def get_all_stars(self) -> List[StarMetadata]:
|
||||
'''获取当前载入的所有插件 Metadata 的列表'''
|
||||
return star_registry
|
||||
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
'''获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools'''
|
||||
'''
|
||||
获取 LLM Tool Manager
|
||||
'''
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False
|
||||
'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
'''停用一个已经注册的函数调用工具。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''获取 AstrBot 的配置。'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''获取 AstrBot 数据库。'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
'''
|
||||
获取事件队列。
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
try:
|
||||
session = MessageSesion.from_str(session)
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
'''
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
'''
|
||||
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
@@ -201,7 +94,41 @@ class Context:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False
|
||||
'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
'''停用一个已经注册的函数调用工具。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
注册一个命令。
|
||||
@@ -235,6 +162,77 @@ class Context:
|
||||
))
|
||||
star_handlers_registry.append(md)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''
|
||||
通过 ID 获取 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''
|
||||
获取所有 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的 LLM Provider(Chat_Completion 类型)。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''
|
||||
获取 AstrBot 配置信息。
|
||||
'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''
|
||||
获取 AstrBot 数据库。
|
||||
'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
'''
|
||||
获取事件队列。
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
try:
|
||||
session = MessageSesion.from_str(session)
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_task(self, task: Awaitable, desc: str):
|
||||
'''
|
||||
注册一个异步任务。
|
||||
|
||||
@@ -7,7 +7,6 @@ from .star_handler import (
|
||||
register_regex,
|
||||
register_permission_type,
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent
|
||||
@@ -22,7 +21,6 @@ __all__ = [
|
||||
'register_regex',
|
||||
'register_permission_type',
|
||||
'register_on_llm_request',
|
||||
'register_on_llm_response',
|
||||
'register_llm_tool',
|
||||
'register_on_decorating_result',
|
||||
'register_after_message_sent'
|
||||
|
||||
@@ -139,8 +139,6 @@ def register_on_llm_request():
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
|
||||
@on_llm_request()
|
||||
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
|
||||
request.system_prompt += "你是一个猫娘..."
|
||||
@@ -154,27 +152,6 @@ def register_on_llm_request():
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_response():
|
||||
'''当有 LLM 请求后的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.provider import LLMResponse
|
||||
|
||||
@on_llm_response()
|
||||
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
|
||||
...
|
||||
```
|
||||
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
star_map: Dict[str, StarMetadata] = {}
|
||||
@@ -12,7 +11,7 @@ star_map: Dict[str, StarMetadata] = {}
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
'''
|
||||
插件的元数据。
|
||||
Star 的元数据。
|
||||
'''
|
||||
name: str
|
||||
author: str # 插件作者
|
||||
@@ -21,24 +20,21 @@ class StarMetadata:
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
star_cls_type: type = None
|
||||
'''插件的类对象的类型'''
|
||||
'''Star 的类对象的类型'''
|
||||
module_path: str = None
|
||||
'''插件的模块路径'''
|
||||
'''Star 的模块路径'''
|
||||
|
||||
star_cls: object = None
|
||||
'''插件的类对象'''
|
||||
'''Star 的类对象'''
|
||||
module: ModuleType = None
|
||||
'''插件的模块对象'''
|
||||
'''Star 的模块对象'''
|
||||
root_dir_name: str = None
|
||||
'''插件的目录名称'''
|
||||
'''Star 的根目录名'''
|
||||
reserved: bool = False
|
||||
'''是否是 AstrBot 的保留插件'''
|
||||
'''是否是 AstrBot 的保留 Star'''
|
||||
|
||||
activated: bool = True
|
||||
'''是否被激活'''
|
||||
|
||||
config: AstrBotConfig = None
|
||||
'''插件配置'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
@@ -47,7 +47,6 @@ class EventType(enum.Enum):
|
||||
'''
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
@@ -2,14 +2,12 @@ import inspect
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import DEFAULT_VALUE_MAP
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
@@ -28,20 +26,13 @@ class PluginManager:
|
||||
self.updator = PluginUpdator(config['plugin_repo_mirror'])
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self
|
||||
self.context._star_manager = self # 就这样吧,不想改了
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
|
||||
'''存储插件的路径。即 data/plugins'''
|
||||
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
|
||||
'''存储插件配置的路径。data/config'''
|
||||
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
|
||||
'''保留插件的路径。在 packages 目录下'''
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
'''插件配置 Schema 文件名'''
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
|
||||
classes = []
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
for (name, _) in clsmembers:
|
||||
@@ -137,7 +128,7 @@ class PluginManager:
|
||||
return metadata
|
||||
|
||||
async def reload(self):
|
||||
'''扫描并加载所有的插件'''
|
||||
'''扫描并加载所有的 Star'''
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
@@ -159,13 +150,13 @@ class PluginManager:
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
# 导入插件模块,并尝试实例化插件类
|
||||
# 导入 Star 模块,并尝试实例化 Star 类
|
||||
for plugin_module in plugin_modules:
|
||||
try:
|
||||
module_str = plugin_module['module']
|
||||
# module_path = plugin_module['module_path']
|
||||
root_dir_name = plugin_module['pname'] # 插件的目录名
|
||||
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
root_dir_name = plugin_module['pname']
|
||||
reserved = plugin_module.get('reserved', False)
|
||||
|
||||
logger.info(f"正在载入插件 {root_dir_name} ...")
|
||||
|
||||
@@ -182,33 +173,11 @@ class PluginManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
|
||||
continue
|
||||
|
||||
# 检查 _conf_schema.json
|
||||
plugin_config = None
|
||||
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
|
||||
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
|
||||
if os.path.exists(plugin_schema_path):
|
||||
# 加载插件配置
|
||||
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
|
||||
plugin_config = AstrBotConfig(
|
||||
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
|
||||
schema=json.loads(f.read())
|
||||
)
|
||||
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
metadata = star_map[path]
|
||||
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
|
||||
except TypeError as _:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
else:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
@@ -230,20 +199,16 @@ class PluginManager:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
|
||||
classes = self._get_classes(module)
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"插件 {root_dir_name} 实例化失败。")
|
||||
raise e
|
||||
|
||||
if plugin_config:
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
|
||||
except TypeError as _:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
else:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
|
||||
metadata = None
|
||||
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
@@ -256,7 +221,7 @@ class PluginManager:
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
# 执行 initialize() 方法
|
||||
# 执行 initialize 函数
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
@@ -327,14 +292,13 @@ class PluginManager:
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
@@ -359,9 +323,8 @@ class PluginManager:
|
||||
plugin.activated = True
|
||||
|
||||
|
||||
async def install_plugin_from_file(self, zip_file_path: str):
|
||||
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
||||
desti_dir = os.path.join(self.plugin_store_path, dir_name)
|
||||
def install_plugin_from_file(self, zip_file_path: str):
|
||||
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
|
||||
self.updator.unzip_file(zip_file_path, desti_dir)
|
||||
|
||||
# remove the zip
|
||||
@@ -369,4 +332,6 @@ class PluginManager:
|
||||
os.remove(zip_file_path)
|
||||
except BaseException as e:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
await self.reload()
|
||||
|
||||
self._check_plugin_dept_update()
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ import time
|
||||
import aiohttp
|
||||
import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -43,21 +41,21 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
return False
|
||||
|
||||
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
def save_temp_img(img: Image) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
# 获得文件创建时间,清除超过1小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600*12:
|
||||
if time.time() - ctime > 3600:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
print(f"清除临时文件失败: {e}")
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
timestamp = int(time.time())
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
@@ -109,7 +107,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, timeout=1800) as resp:
|
||||
async with session.get(url, timeout=120) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
|
||||
@@ -20,23 +20,18 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
|
||||
|
||||
return output_path
|
||||
|
||||
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
'''返回 duration'''
|
||||
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
|
||||
import pysilk
|
||||
|
||||
with wave.open(wav_path, 'rb') as wav:
|
||||
wav_data = wav.readframes(wav.getnframes())
|
||||
wav_data = BytesIO(wav_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.encode(wav_data, output_io, 24000, 24000)
|
||||
pysilk.encode(wav_data, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
|
||||
# 在首字节添加 \x02,去除结尾的\xff\xff
|
||||
# 在首字节添加 \x02
|
||||
silk_data = output_io.read()
|
||||
silk_data_with_prefix = b'\x02' + silk_data[:-2]
|
||||
silk_data_with_prefix = b'\x02' + silk_data
|
||||
|
||||
# return BytesIO(silk_data_with_prefix)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(silk_data_with_prefix)
|
||||
|
||||
return 0
|
||||
return BytesIO(silk_data_with_prefix)
|
||||
@@ -39,6 +39,7 @@ class RepoZipUpdator():
|
||||
else:
|
||||
ret = self.github_api_release_parser(result)
|
||||
except BaseException:
|
||||
logger.error("解析版本信息失败")
|
||||
raise Exception("解析版本信息失败")
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import os
|
||||
import json
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.config import update_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
@@ -17,9 +19,9 @@ def try_cast(value: str, type_: str):
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
|
||||
def validate_config(data, schema: dict, is_core: bool):
|
||||
def validate_config(data, config: AstrBotConfig):
|
||||
errors = []
|
||||
def validate(data, metadata=schema, path=""):
|
||||
def validate(data, metadata=CONFIG_METADATA_2, path=""):
|
||||
for key, meta in metadata.items():
|
||||
if key not in data:
|
||||
continue
|
||||
@@ -54,33 +56,35 @@ def validate_config(data, schema: dict, is_core: bool):
|
||||
elif meta["type"] == "object" and not isinstance(value, dict):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
|
||||
if is_core:
|
||||
for key, group in schema.items():
|
||||
group_meta = group.get("metadata")
|
||||
if not group_meta:
|
||||
continue
|
||||
logger.info(f"验证配置: 组 {key} ...")
|
||||
validate(data, group_meta, path=f"{key}.")
|
||||
else:
|
||||
validate(data, schema)
|
||||
validate(data)
|
||||
|
||||
return errors
|
||||
|
||||
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
||||
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
|
||||
'''验证并保存配置'''
|
||||
errors = None
|
||||
try:
|
||||
if is_core:
|
||||
errors = validate_config(post_config, CONFIG_METADATA_2, is_core)
|
||||
else:
|
||||
errors = validate_config(post_config, config.schema, is_core)
|
||||
except BaseException as e:
|
||||
logger.warning(f"验证配置时出现异常: {e}")
|
||||
errors = validate_config(post_config, config)
|
||||
if errors:
|
||||
raise ValueError(f"格式校验未通过: {errors}")
|
||||
config.save_config(post_config)
|
||||
|
||||
def save_extension_config(post_config: dict):
|
||||
if 'namespace' not in post_config:
|
||||
raise ValueError("Missing key: namespace")
|
||||
if 'config' not in post_config:
|
||||
raise ValueError("Missing key: config")
|
||||
|
||||
namespace = post_config['namespace']
|
||||
config: list = post_config['config'][0]['body']
|
||||
for item in config:
|
||||
key = item['path']
|
||||
value = item['value']
|
||||
typ = item['val_type']
|
||||
if typ == 'int':
|
||||
if not value.isdigit():
|
||||
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
value = int(value)
|
||||
update_config(namespace, key, value)
|
||||
|
||||
class ConfigRoute(Route):
|
||||
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
@@ -88,17 +92,17 @@ class ConfigRoute(Route):
|
||||
self.routes = {
|
||||
'/config/get': ('GET', self.get_configs),
|
||||
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
|
||||
'/config/plugin/update': ('POST', self.post_plugin_configs),
|
||||
'/config/plugin/update': ('POST', self.post_extension_configs),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def get_configs(self):
|
||||
# plugin_name 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 plugin_name 的插件配置
|
||||
plugin_name = request.args.get("plugin_name", None)
|
||||
if not plugin_name:
|
||||
# namespace 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 namespace 的插件配置
|
||||
namespace = "" if "namespace" not in request.args else request.args["namespace"]
|
||||
if not namespace:
|
||||
return Response().ok(await self._get_astrbot_config()).__dict__
|
||||
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
|
||||
return Response().ok(await self._get_extension_config(namespace)).__dict__
|
||||
|
||||
async def post_astrbot_configs(self):
|
||||
post_configs = await request.json
|
||||
@@ -106,15 +110,14 @@ class ConfigRoute(Route):
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
traceback.print_exc()
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def post_plugin_configs(self):
|
||||
async def post_extension_configs(self):
|
||||
post_configs = await request.json
|
||||
plugin_name = request.args.get("plugin_name", "unknown")
|
||||
try:
|
||||
await self._save_plugin_configs(post_configs, plugin_name)
|
||||
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
|
||||
await self._save_extension_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -138,48 +141,28 @@ class ConfigRoute(Route):
|
||||
"config": config
|
||||
}
|
||||
|
||||
async def _get_plugin_config(self, plugin_name: str):
|
||||
ret = {
|
||||
"metadata": None,
|
||||
"config": None
|
||||
}
|
||||
|
||||
for plugin_md in star_registry:
|
||||
if plugin_md.name == plugin_name:
|
||||
if not plugin_md.config:
|
||||
break
|
||||
ret['config'] = plugin_md.config # 这是自定义的 Dict 类(AstrBotConfig)
|
||||
ret['metadata'] = {
|
||||
plugin_name: {
|
||||
"description": f"{plugin_name} 配置",
|
||||
"type": "object",
|
||||
"items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
return ret
|
||||
|
||||
async def _get_extension_config(self, namespace: str):
|
||||
path = f"data/config/{namespace}.json"
|
||||
if not os.path.exists(path):
|
||||
return []
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
return [{
|
||||
"config_type": "group",
|
||||
"name": namespace + " 插件配置",
|
||||
"description": "",
|
||||
"body": list(json.load(f).values())
|
||||
},]
|
||||
|
||||
async def _save_astrbot_configs(self, post_configs: dict):
|
||||
try:
|
||||
save_config(post_configs, self.config, is_core=True)
|
||||
save_astrbot_config(post_configs, self.config)
|
||||
self.core_lifecycle.restart()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
|
||||
md = None
|
||||
for plugin_md in star_registry:
|
||||
if plugin_md.name == plugin_name:
|
||||
md = plugin_md
|
||||
|
||||
if not md:
|
||||
raise ValueError(f"插件 {plugin_name} 不存在")
|
||||
if not md.config:
|
||||
raise ValueError(f"插件 {plugin_name} 没有注册配置")
|
||||
|
||||
|
||||
async def _save_extension_configs(self, post_configs: dict):
|
||||
try:
|
||||
save_config(post_configs, md.config)
|
||||
save_extension_config(post_configs)
|
||||
self.core_lifecycle.restart()
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -67,9 +67,9 @@ class PluginRoute(Route):
|
||||
file = await request.files
|
||||
file = file['file']
|
||||
logger.info(f"正在安装用户上传的插件 {file.filename}")
|
||||
file_path = f"data/temp/{file.filename}"
|
||||
file_path = f"data/temp/{uuid.uuid4()}.zip"
|
||||
await file.save(file_path)
|
||||
await self.plugin_manager.install_plugin_from_file(file_path)
|
||||
self.plugin_manager.install_plugin_from_file(file_path)
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
|
||||
@@ -6,10 +6,6 @@ class StaticFileRoute(Route):
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat']
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
@self.app.errorhandler(404)
|
||||
async def page_not_found(e):
|
||||
return "404 Not found。如果你初次使用打开面板发现 404,请参考文档: https://astrbot.app/deploy/dashboard-404.html"
|
||||
|
||||
async def index(self):
|
||||
return await self.app.send_static_file('index.html')
|
||||
@@ -1,5 +1,4 @@
|
||||
import traceback
|
||||
import aiohttp
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -44,7 +43,7 @@ class UpdateRoute(Route):
|
||||
}
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def update_project(self):
|
||||
@@ -64,13 +63,6 @@ class UpdateRoute(Route):
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
|
||||
# pip 更新依赖
|
||||
logger.info("更新依赖中...")
|
||||
try:
|
||||
pip_installer.install(requirements_path="requirements.txt")
|
||||
except Exception as e:
|
||||
logger.error(f"更新依赖失败: {e}")
|
||||
|
||||
if reboot:
|
||||
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
- Gewechat 微信支持图片、语音的收和发
|
||||
- 支持 OpenAI TTS(文字转语音)
|
||||
- 支持路径映射,解决 docker 部署时两端文件系统不一致导致的富媒体文件路径不存在问题
|
||||
- Napcat 下语音消息可能接收异常
|
||||
@@ -1,4 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
- 修复 astrbot_updator 属性缺失与stt_enabled 未初始化 #252
|
||||
- 支持消息分段回复
|
||||
@@ -1,8 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
- 修复: TTS 问题
|
||||
- 新增: **支持记录非唤醒状态下群聊历史记录(beta)**
|
||||
- 优化: 自动删除 deepseek-r1 模型自带的 think 标签
|
||||
- 优化: 自动移除 ollama 不支持 tool 的模型的 tool 请求
|
||||
- 优化: /t2i 即时生效
|
||||
- 优化: gewechat 消息下发异常处理
|
||||
@@ -1,9 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
- 修复: 配置 Validator 不起效的问题
|
||||
- 修复: DeepSeek-R1 思考标签问题
|
||||
- 修复: 分段回复间隔时间不生效
|
||||
- 修复: 修复白名单为空时依然终止事件 #259
|
||||
- 修复: 群聊增强某些参数的类型转换问题
|
||||
- 新增: 插件支持注册配置,详见 [注册插件配置](https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta)
|
||||
- 优化: 插件的禁用/启用逻辑以及函数工具的禁用/启用逻辑
|
||||
@@ -1,6 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
- [gewechat] [修复每次启动astrbot都需要扫码的问题](https://github.com/Soulter/AstrBot/commit/fd5d7dd37a6d74f81a148bbebef8516aa0cb5540)
|
||||
- [core] [Provider 重复时不直接报错闪退](https://github.com/Soulter/AstrBot/commit/b61f9be18db9a6b8b3c5b6b36553f66dd2b79375) https://github.com/Soulter/AstrBot/issues/265
|
||||
- [core] [弱化更新报错](https://github.com/Soulter/AstrBot/commit/0ba0150fd8ff2062dbe83889163888ba3e33bd49) https://github.com/Soulter/AstrBot/issues/267
|
||||
- 修复 webui 无法从本地上传插件的问题
|
||||
@@ -1,11 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
- [beta] 支持群聊内基于概率的主动回复
|
||||
- openai tts 更换模型 #300
|
||||
- 增加模型响应后的插件钩子
|
||||
- 修复 相同type的provider共享了记忆
|
||||
- 优化 人格情景在发现格式不对时仍然加载而不是跳过 #282
|
||||
- 修复 Gemini函数调用时,parameters为空对象导致的错误 by @Camreishi
|
||||
- 修复 弹出记录报错的问题 #272
|
||||
- 优化 移除默认人格
|
||||
- 优化 未启用模型提供商时的异常处理
|
||||
@@ -4,8 +4,6 @@ services:
|
||||
astrbot:
|
||||
image: soulter/astrbot:latest
|
||||
container_name: astrbot
|
||||
ports:
|
||||
- "6180-6200:6180-6200"
|
||||
- "11451:11451"
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ./data:/AstrBot/data
|
||||
- ./data:/AstrBot/data
|
||||
41
dashboard/src/components/shared/ConfigDetailCard.vue
Normal file
41
dashboard/src/components/shared/ConfigDetailCard.vue
Normal file
@@ -0,0 +1,41 @@
|
||||
<script setup>
|
||||
import UiParentCard from '@/components/shared/UiParentCard.vue';
|
||||
|
||||
const props = defineProps({
|
||||
config: Array
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<a v-show="config.length === 0">该插件没有配置</a>
|
||||
<UiParentCard v-for="group in config" :key="group.name" :title="group.name" style="margin-bottom: 16px;">
|
||||
<template v-for="item in group.body">
|
||||
<template v-if="item.config_type === 'item'">
|
||||
<template v-if="item.val_type === 'bool'">
|
||||
<v-switch v-model="item.value" :label="item.name" :hint="item.description" color="primary" inset></v-switch>
|
||||
</template>
|
||||
<template v-else-if="item.val_type === 'str'">
|
||||
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
|
||||
variant="outlined"></v-text-field>
|
||||
</template>
|
||||
<template v-else-if="item.val_type === 'int'">
|
||||
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
|
||||
variant="outlined"></v-text-field>
|
||||
</template>
|
||||
<template v-else-if="item.val_type === 'list'">
|
||||
<span>{{ item.name }}</span>
|
||||
<v-combobox v-model="item.value" chips clearable label="请添加" multiple prepend-icon="mdi-tag-multiple-outline">
|
||||
<template v-slot:selection="{ attrs, item, select, selected }">
|
||||
<v-chip v-bind="attrs" :model-value="selected" closable @click="select" @click:close="remove(item)">
|
||||
<strong>{{ item }}</strong>
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-combobox>
|
||||
</template>
|
||||
</template>
|
||||
<template v-else-if="item.config_type === 'divider'">
|
||||
<v-divider style="margin-top: 8px; margin-bottom: 8px;"></v-divider>
|
||||
</template>
|
||||
</template>
|
||||
</UiParentCard>
|
||||
</template>
|
||||
@@ -1,7 +1,7 @@
|
||||
<script setup>
|
||||
import ExtensionCard from '@/components/shared/ExtensionCard.vue';
|
||||
import ConfigDetailCard from '@/components/shared/ConfigDetailCard.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import axios from 'axios';
|
||||
|
||||
@@ -52,17 +52,11 @@ import axios from 'axios';
|
||||
<v-btn v-else variant="plain" disabled>已安装</v-btn>
|
||||
</div>
|
||||
</ExtensionCard>
|
||||
|
||||
</v-col>
|
||||
|
||||
<v-col style="margin-bottom: 16px;" cols="12" md="12">
|
||||
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
|
||||
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
|
||||
</v-col>
|
||||
|
||||
</v-row>
|
||||
|
||||
<v-dialog v-model="configDialog" width="1000">
|
||||
<v-dialog v-model="configDialog" width="750">
|
||||
<template v-slot:activator="{ props }">
|
||||
</template>
|
||||
<v-card>
|
||||
@@ -71,8 +65,7 @@ import axios from 'axios';
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
|
||||
<p v-else>这个插件没有配置</p>
|
||||
<ConfigDetailCard :config="extension_config"></ConfigDetailCard>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
@@ -179,9 +172,9 @@ export default {
|
||||
name: 'ExtensionPage',
|
||||
components: {
|
||||
ExtensionCard,
|
||||
ConfigDetailCard,
|
||||
WaitingForRestart,
|
||||
ConsoleDisplayer,
|
||||
AstrBotConfig
|
||||
ConsoleDisplayer
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
@@ -196,10 +189,7 @@ export default {
|
||||
snack_success: "success",
|
||||
loading_: false,
|
||||
configDialog: false,
|
||||
extension_config: {
|
||||
"metadata": {},
|
||||
"config": {}
|
||||
},
|
||||
extension_config: {},
|
||||
upload_file: null,
|
||||
pluginMarketData: {},
|
||||
loadingDialog: {
|
||||
@@ -374,7 +364,7 @@ export default {
|
||||
openExtensionConfig(extension_name) {
|
||||
this.curr_namespace = extension_name;
|
||||
this.configDialog = true;
|
||||
axios.get('/api/config/get?plugin_name=' + extension_name).then((res) => {
|
||||
axios.get('/api/config/get?namespace=' + extension_name).then((res) => {
|
||||
this.extension_config = res.data.data;
|
||||
console.log(this.extension_config);
|
||||
}).catch((err) => {
|
||||
@@ -382,7 +372,10 @@ export default {
|
||||
});
|
||||
},
|
||||
updateConfig() {
|
||||
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
|
||||
axios.post('/api/config/plugin/update', {
|
||||
"config": this.extension_config,
|
||||
"namespace": this.curr_namespace
|
||||
}).then((res) => {
|
||||
if (res.data.status === "ok") {
|
||||
this.toast(res.data.message, "success");
|
||||
this.$refs.wfr.check();
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
import datetime
|
||||
import uuid
|
||||
import random
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot import logger
|
||||
from collections import defaultdict
|
||||
|
||||
'''
|
||||
聊天记忆增强
|
||||
'''
|
||||
class LongTermMemory:
|
||||
def __init__(self, config: dict, context: star.Context):
|
||||
self.config = config
|
||||
self.context = context
|
||||
self.session_chats = defaultdict(list)
|
||||
"""记录群成员的群聊记录"""
|
||||
try:
|
||||
self.max_cnt = int(self.config["group_message_max_cnt"])
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
self.max_cnt = 300
|
||||
self.image_caption = self.config["image_caption"]
|
||||
self.image_caption_prompt = self.config["image_caption_prompt"]
|
||||
|
||||
self.active_reply = self.config["active_reply"]
|
||||
self.ar_method = self.active_reply["method"]
|
||||
self.ar_possibility = self.active_reply["possibility_reply"]
|
||||
self.ar_prompt = self.active_reply.get("prompt", "")
|
||||
|
||||
self.put_history_to_prompt = self.config["put_history_to_prompt"]
|
||||
|
||||
async def remove_session(self, event: AstrMessageEvent) -> int:
|
||||
cnt = 0
|
||||
if event.unified_msg_origin in self.session_chats:
|
||||
cnt = len(self.session_chats[event.unified_msg_origin])
|
||||
del self.session_chats[event.unified_msg_origin]
|
||||
return cnt
|
||||
|
||||
async def get_image_caption(self, image_url: str) -> str:
|
||||
provider = self.context.get_using_provider()
|
||||
response = await provider.text_chat(
|
||||
prompt=self.image_caption_prompt,
|
||||
session_id=uuid.uuid4().hex,
|
||||
image_urls=[image_url],
|
||||
persist=False,
|
||||
)
|
||||
return response.completion_text
|
||||
|
||||
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
|
||||
if not self.active_reply:
|
||||
return False
|
||||
if event.get_message_type() != MessageType.GROUP_MESSAGE:
|
||||
return False
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
# if the message is a command, let it pass
|
||||
return False
|
||||
|
||||
match self.ar_method:
|
||||
case "possibility_reply":
|
||||
trig = random.random() < self.ar_possibility
|
||||
return trig
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def handle_message(self, event: AstrMessageEvent):
|
||||
'''仅支持群聊'''
|
||||
if event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
|
||||
|
||||
for comp in event.get_messages():
|
||||
if isinstance(comp, Plain):
|
||||
final_message += f" {comp.text}"
|
||||
elif isinstance(comp, Image):
|
||||
# image_urls.append(comp.url if comp.url else comp.file)
|
||||
if self.image_caption:
|
||||
try:
|
||||
caption = await self.get_image_caption(
|
||||
comp.url if comp.url else comp.file
|
||||
)
|
||||
final_message += f" [Image: {caption}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {e}")
|
||||
else:
|
||||
final_message += " [Image]"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
'''当触发 LLM 请求前,调用此方法修改 req'''
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
|
||||
|
||||
if self.put_history_to_prompt:
|
||||
prompt = req.prompt
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.contexts = [] # 清空上下文,当使用了群聊增强,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
if event.get_result() and event.get_result().is_llm_result():
|
||||
final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
@@ -5,17 +5,13 @@ import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
from collections import defaultdict
|
||||
from .long_term_memory import LongTermMemory
|
||||
from astrbot.core import logger
|
||||
|
||||
from typing import Union
|
||||
|
||||
@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0")
|
||||
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
@@ -24,12 +20,7 @@ class Main(star.Star):
|
||||
self.identifier = cfg['provider_settings']['identifier']
|
||||
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
|
||||
|
||||
self.ltm = None
|
||||
if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']:
|
||||
try:
|
||||
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"聊天增强 err: {e}")
|
||||
self.kdb_enabled = False
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
@@ -203,10 +194,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
async def provider(self, event: AstrMessageEvent, idx: int = None):
|
||||
'''查看或者切换 LLM Provider'''
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
if idx is None:
|
||||
ret = "## 当前载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
@@ -231,27 +218,11 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
await self.context.get_using_provider().forget(message.session_id)
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm:
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
message.set_result(MessageEventResult().message("重置成功"))
|
||||
|
||||
@filter.command("model")
|
||||
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
|
||||
if idx_or_name is None:
|
||||
models = []
|
||||
try:
|
||||
@@ -292,12 +263,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
size_per_page = 3
|
||||
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
|
||||
|
||||
@@ -317,10 +282,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int=None):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
if index is None:
|
||||
keys_data = self.context.get_using_provider().get_keys()
|
||||
@@ -349,12 +310,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
|
||||
l = message.message_str.split(" ")
|
||||
|
||||
curr_persona_name = "无"
|
||||
@@ -368,7 +323,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格情景列表: `/persona list`
|
||||
- 人格情景详细信息: `/persona view 人格名`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
当前人格情景: {curr_persona_name}
|
||||
|
||||
@@ -394,9 +348,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "unset":
|
||||
self.context.get_using_provider().curr_personality = None
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if persona := next(builtins.filter(
|
||||
@@ -404,9 +355,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
self.context.provider_manager.personas
|
||||
), None):
|
||||
self.context.get_using_provider().curr_personality = persona
|
||||
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
else:
|
||||
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
@@ -415,6 +366,31 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await download_dashboard()
|
||||
yield event.plain_result("管理面板更新完成。")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
provider = self.context.get_using_provider()
|
||||
if self.prompt_prefix:
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
req.prompt = user_info + req.prompt
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# if provider.curr_personality['prompt']:
|
||||
# req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
|
||||
@filter.command("set")
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
session_id = event.get_session_id()
|
||||
@@ -452,111 +428,32 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
await platform.logout()
|
||||
yield event.plain_result("已登出 gewechat")
|
||||
return
|
||||
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''群聊记忆增强'''
|
||||
if self.ltm:
|
||||
need_active = await self.ltm.need_active_reply(event)
|
||||
|
||||
group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable']
|
||||
if group_icl_enable:
|
||||
'''记录对话'''
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
|
||||
if need_active:
|
||||
'''主动回复'''
|
||||
provider = self.context.get_using_provider()
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
|
||||
prompt = self.ltm.ar_prompt
|
||||
if not prompt:
|
||||
prompt = event.message_str
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
contexts=session_provider_context if session_provider_context else []
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
pass
|
||||
|
||||
@kdb.command("on")
|
||||
async def on_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = True
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if not curr_kdb_name:
|
||||
yield event.plain_result("未载入任何知识库")
|
||||
else:
|
||||
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
|
||||
|
||||
@kdb.command("off")
|
||||
async def off_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = False
|
||||
yield event.plain_result("知识库已关闭")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
|
||||
provider = self.context.get_using_provider()
|
||||
if self.prompt_prefix:
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
req.prompt = user_info + req.prompt
|
||||
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
if self.ltm:
|
||||
try:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
except BaseException as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_llm_req(self, event: AstrMessageEvent):
|
||||
'''在 LLM 请求后记录对话'''
|
||||
if self.ltm:
|
||||
try:
|
||||
await self.ltm.after_req_llm(event)
|
||||
except BaseException as e:
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
# @filter.command_group("kdb")
|
||||
# def kdb(self):
|
||||
# pass
|
||||
|
||||
# @kdb.command("on")
|
||||
# async def on_kdb(self, event: AstrMessageEvent):
|
||||
# self.kdb_enabled = True
|
||||
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
# if not curr_kdb_name:
|
||||
# yield event.plain_result("未载入任何知识库")
|
||||
# else:
|
||||
# yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
|
||||
|
||||
# @kdb.command("off")
|
||||
# async def off_kdb(self, event: AstrMessageEvent):
|
||||
# self.kdb_enabled = False
|
||||
# yield event.plain_result("知识库已关闭")
|
||||
|
||||
# @filter.on_llm_request()
|
||||
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
# if self.kdb_enabled and curr_kdb_name:
|
||||
# mgr = self.context.knowledge_db_manager
|
||||
# results = await mgr.retrive_records(curr_kdb_name, req.prompt)
|
||||
# if results:
|
||||
# req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
# for result in results:
|
||||
# req.system_prompt += f"- {result}\n"
|
||||
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if self.kdb_enabled and curr_kdb_name:
|
||||
mgr = self.context.knowledge_db_manager
|
||||
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
|
||||
if results:
|
||||
req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
for result in results:
|
||||
req.system_prompt += f"- {result}\n"
|
||||
@@ -8,8 +8,7 @@ from typing import List
|
||||
class Google(SearchEngine):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proxy = os.environ.get("https_proxy")
|
||||
print(f"Google Search using proxy: {self.proxy}")
|
||||
self.proxy = os.environ.get("HTTPS_PROXY")
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = []
|
||||
|
||||
@@ -22,8 +22,6 @@ class Main(star.Star):
|
||||
self.sogo_search = Sogo()
|
||||
self.google = Google()
|
||||
|
||||
self.websearch_link = self.context.get_config()['provider_settings'].get('web_search_link', False)
|
||||
|
||||
async def initialize(self):
|
||||
websearch = self.context.get_config()['provider_settings']['web_search']
|
||||
if websearch:
|
||||
@@ -111,17 +109,8 @@ class Main(star.Star):
|
||||
except BaseException:
|
||||
site_result = ""
|
||||
site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result
|
||||
|
||||
header = f"{idx}. {i.title} "
|
||||
|
||||
if self.websearch_link and i.url:
|
||||
header += i.url
|
||||
|
||||
ret += f"{header}\n{i.snippet}\n{site_result}\n\n"
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
|
||||
if self.websearch_link:
|
||||
ret += "针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pydantic~=2.10.3
|
||||
vchat
|
||||
aiohttp
|
||||
openai
|
||||
qq-botpy
|
||||
|
||||
33
stt.py
Normal file
33
stt.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
model_id = "openai/whisper-large-v3"
|
||||
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
chunk_length_s=30,
|
||||
batch_size=16, # batch size for inference - set based on your device
|
||||
torch_dtype=torch_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
result = pipe(sample)
|
||||
print(result["text"])
|
||||
@@ -5,7 +5,6 @@ import asyncio
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
|
||||
Reference in New Issue
Block a user