Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
6c18971f00 feat: 初步接入米家小爱音箱 2025-01-25 02:30:00 +08:00
92 changed files with 1759 additions and 2543 deletions

View File

@@ -1,23 +1,24 @@
<p align="center">
![logo](https://github.com/user-attachments/assets/07649e07-3b8e-4feb-9aa9-bf13af4f3476)
<p align="center">
<img src="https://github.com/user-attachments/assets/de10f24d-cd64-433a-90b8-16c0a60de24a" width=500>
</p>
<div align="center">
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<h1>AstrBot</h1>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
@@ -37,7 +38,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM因此无法在聊天页使用大模型。
## ✨ 使用方式
@@ -66,31 +67,19 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ⚡ 消息平台支持情况
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
| QQ 官方API | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| 飞书 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
# 🦌 接下来的路线图
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
@@ -142,21 +131,8 @@ _✨ 内置 Web Chat在线与机器人交互 ✨_
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## Sponsors
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
@@ -168,6 +144,5 @@ _✨ 内置 Web Chat在线与机器人交互 ✨_
4. TTS
-->
_私は、高性能ですから!_
_アトリは、高性能ですから!_

View File

@@ -6,7 +6,6 @@ from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent
@@ -32,6 +31,5 @@ __all__ = [
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent',
'on_llm_response'
'after_message_sent'
]

View File

@@ -2,5 +2,4 @@ from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
from astrbot.core.platform.register import register_platform_adapter

View File

@@ -1,2 +1,2 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData

View File

@@ -2,7 +2,7 @@ import os
import json
import logging
import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from .default import DEFAULT_CONFIG
from typing import Dict
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
@@ -13,72 +13,29 @@ class RateLimitStrategy(enum.Enum):
DISCARD = "discard"
class AstrBotConfig(dict):
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema将会通过 schema 解析出 default_config此时传入的 default_config 会被忽略。
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None
):
def __init__(self):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
'''不存在时载入默认配置'''
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
with open(config_path, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(default_config, conf)
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
@@ -104,7 +61,7 @@ class AstrBotConfig(dict):
'''
if replace_config:
self.update(replace_config)
with open(self.config_path, "w", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
@@ -124,4 +81,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(self.config_path)
return os.path.exists(ASTRBOT_CONFIG_PATH)

View File

@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.20"
VERSION = "3.4.11"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -24,22 +24,13 @@ DEFAULT_CONFIG = {
"wl_ignore_admin_on_friend": True,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": [],
"segmented_reply": {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"seg_prompt": "",
"regex": ".*?[。?!~…]+|.+$"
},
"no_permission_reply": True,
"path_mapping": []
},
"provider": [],
"provider_settings": {
"enable": True,
"wake_prefix": "",
"web_search": False,
"web_search_link": False,
"identifier": False,
"datetime_system_prompt": True,
"default_personality": "default",
@@ -49,23 +40,6 @@ DEFAULT_CONFIG = {
"enable": False,
"provider_id": "",
},
"provider_tts_settings": {
"enable": False,
"provider_id": "",
},
"provider_ltm_settings": {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_prompt": "Please describe the image using Chinese.",
"active_reply": {
"enable": False,
"method": "possibility_reply",
"possibility_reply": 0.1,
"prompt": "",
},
"put_history_to_prompt": True,
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -85,7 +59,14 @@ DEFAULT_CONFIG = {
"pip_install_arg": "",
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [],
"persona": [
{
"name": "default",
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
],
}
@@ -111,18 +92,30 @@ CONFIG_METADATA_2 = {
"id": "default",
"type": "aiocqhttp",
"enable": False,
"ws_reverse_host": "0.0.0.0",
"ws_reverse_host": "",
"ws_reverse_port": 6199,
},
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
"gewechat(微信)": {
"id": "gwchat",
"type": "gewechat",
"enable": False,
"base_url": "http://localhost:2531",
"nickname": "soulter",
"host": "这里填写你的局域网IP或者公网服务器IP",
"host": "localhost",
"port": 11451,
},
"mispeaker(小爱音箱)": {
"id": "mispeaker",
"type": "mispeaker",
"enable": False,
"username": "",
"password": "",
"did": "",
"activate_word": "测试",
"deactivate_word": "停止",
"interval": 1,
},
},
"items": {
"id": {
@@ -196,41 +189,6 @@ CONFIG_METADATA_2 = {
},
},
},
"no_permission_reply": {
"description": "无权限回复",
"type": "bool",
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
"items": {
"enable": {
"description": "启用分段回复",
"type": "bool",
},
"only_llm_result": {
"description": "仅对 LLM 结果分段",
"type": "bool",
},
"interval": {
"description": "随机间隔时间(秒)",
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"seg_prompt": {
"description": "分段提示词辅助",
"type": "string",
"hint": "此项为空时表达不启用这个方法。此方法会调用一次LLM请求。让 LLM 在某一句话中插入一个可以用正则表达式分隔的标记来实现LLM基于情感分段。如: `请基于情感对以下文本进行分段, 并在两段之间添加`<seg>`以便我用正则匹配。` 然后将下面的正则表达式更换为`.+?<seg>`。",
},
"regex": {
"description": "正则表达式",
"type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
},
},
},
"reply_prefix": {
"description": "回复前缀",
"type": "string",
@@ -248,9 +206,8 @@ CONFIG_METADATA_2 = {
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID当一条消息没通过白名单时会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
"items": {"type": "int"},
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID当一条消息没通过白名单时会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -278,7 +235,6 @@ CONFIG_METADATA_2 = {
"path_mapping": {
"description": "路径映射",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
}
@@ -333,21 +289,10 @@ CONFIG_METADATA_2 = {
"type": "list",
"config_template": {
"openai": {
"id": "openai",
"id": "default",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.openai.com/v1",
"model_config": {
"model": "gpt-4o-mini",
},
},
"azure_openai": {
"id": "azure",
"type": "openai_chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"model_config": {
"model": "gpt-4o-mini",
@@ -403,16 +348,6 @@ CONFIG_METADATA_2 = {
"model": "glm-4-flash",
},
},
"硅基流动": {
"id": "siliconflow",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
},
},
"llmtuner": {
"id": "llmtuner_default",
"type": "llm_tuner",
@@ -447,24 +382,8 @@ CONFIG_METADATA_2 = {
"type": "openai_whisper_selfhost",
"model": "tiny",
},
"openai_tts(API)": {
"id": "openai_tts",
"type": "openai_tts_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "tts-1",
"openai-tts-voice": "alloy",
"timeout": "20",
},
},
"items": {
"openai-tts-voice": {
"description": "voice",
"type": "string",
"obvious_hint": True,
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
},
"whisper_hint": {
"description": "本地部署 Whisper 模型须知",
"type": "string",
@@ -495,7 +414,7 @@ CONFIG_METADATA_2 = {
"api_base": {
"description": "API Base URL",
"type": "string",
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现 404 报错,可以尝试在地址末尾加上 `/v1`。",
"obvious_hint": True,
},
"base_model_path": {
@@ -581,25 +500,16 @@ CONFIG_METADATA_2 = {
"web_search": {
"description": "启用网页搜索",
"type": "bool",
"obvious_hint": True,
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
},
"web_search_link": {
"description": "网页搜索引用链接",
"type": "bool",
"obvious_hint": True,
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
},
"identifier": {
"description": "启动识别群员",
"type": "bool",
"obvious_hint": True,
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
},
"datetime_system_prompt": {
"description": "启用日期时间系统提示",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
},
"default_personality": {
@@ -641,15 +551,15 @@ CONFIG_METADATA_2 = {
"begin_dialogs": {
"description": "预设对话",
"type": "list",
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
"items": {},
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
"obvious_hint": True,
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
"items": {},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
"obvious_hint": True,
},
},
@@ -671,87 +581,6 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_tts_settings": {
"description": "文本转语音(TTS)",
"type": "object",
"items": {
"enable": {
"description": "启用文本转语音(TTS)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个TTS提供商",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
"provider_ltm_settings": {
"description": "聊天记忆增强(Beta)",
"type": "object",
"items": {
"group_icl_enable": {
"description": "群聊内记录各群员对话",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
},
"group_message_max_cnt": {
"description": "群聊消息最大数量",
"type": "int",
"obvious_hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需要模型支持)",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string"
},
"active_reply": {
"description": "主动回复",
"type": "object",
"items": {
"enable": {
"description": "启用主动回复",
"type": "bool",
"obvious_hint": True,
"hint": "启用后会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
},
"method": {
"description": "回复方法",
"type": "string",
"options": ["possibility_reply"],
"hint": "回复方法。possibility_reply 为根据概率回复",
},
"possibility_reply": {
"description": "回复概率",
"type": "float",
"obvious_hint": True,
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
},
"prompt": {
"description": "提示词",
"type": "string",
"obvious_hint": True,
"hint": "提示词。当提示词为空时如果触发回复prompt是触发的消息的内容否则是提示词。此项可以和定时回复暂未实现配合使用。",
},
},
},
"put_history_to_prompt": {
"description": "将群聊历史记录作为 prompt",
"type": "bool",
"obvious_hint": True,
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复建议启用模型能够更好地完成回复任务。",
}
},
},
},
},
"misc_config_group": {
@@ -761,8 +590,7 @@ CONFIG_METADATA_2 = {
"description": "机器人唤醒前缀",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`则内置指令help等将需要通过您的唤醒前缀来触发。",
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
},
"t2i": {
"description": "文本转图像",
@@ -772,8 +600,8 @@ CONFIG_METADATA_2 = {
"admins_id": {
"description": "管理员 ID",
"type": "list",
"items": {"type": "string"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
"items": {"type": "int"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {
"description": "HTTP 代理",

View File

@@ -1,118 +0,0 @@
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
class ConversationManager():
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
'''新建对话,并将当前会话的对话转移到新对话'''
conversation_id = str(uuid.uuid4())
self.db.new_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
'''切换会话的对话'''
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
del self.session_conversations[unified_msg_origin]
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
'''获取会话当前的对话 ID'''
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
'''获取会话的对话'''
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
'''获取会话的所有对话'''
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
'''更新会话的对话'''
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history)
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
'''更新会话的对话标题'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin,
cid=conversation_id,
title=title
)
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
'''更新会话的对话 Persona ID'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin,
cid=conversation_id,
persona_id=persona_id
)
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages

View File

@@ -7,6 +7,7 @@ from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
@@ -18,15 +19,16 @@ from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = astrbot_config
self.db = db
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
if self.astrbot_config['http_proxy']:
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
@@ -43,15 +45,12 @@ class AstrBotCoreLifecycle:
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.conversation_manager = ConversationManager(self.db)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)

View File

@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -79,35 +79,25 @@ class BaseDatabase(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
'''通过 user_id 和 cid 获取 Conversation'''
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def new_conversation(self, user_id: str, cid: str):
'''新建 Conversation'''
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def get_conversations(self, user_id: str) -> List[Conversation]:
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
raise NotImplementedError
@abc.abstractmethod
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新 Conversation'''
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_conversation(self, user_id: str, cid: str):
'''删除 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
'''更新 Conversation 标题'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
'''更新 Conversation Persona ID'''
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
raise NotImplementedError

View File

@@ -33,16 +33,16 @@ class Stats():
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
'''LLM 聊天时持久化的信息'''
@dataclass
class LLMHistory():
'''LLM 聊天时持久化的信息'''
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision():
'''Deprecated'''
id: str
url_or_path: str
caption: str
@@ -53,18 +53,13 @@ class ATRIVision():
sender_nickname: str
timestamp: int = -1
@dataclass
class Conversation():
'''LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
'''
@dataclass
class WebChatConversation():
user_id: str
cid: str
history: str = ""
'''字符串格式的列表。'''
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""

View File

@@ -6,7 +6,7 @@ from astrbot.core.db.po import (
Stats,
LLMHistory,
ATRIVision,
Conversation
WebChatConversation
)
from . import BaseDatabase
from typing import Tuple
@@ -25,37 +25,6 @@ class SQLiteDatabase(BaseDatabase):
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
'''
PRAGMA table_info(webchat_conversation)
'''
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
'''
)
self.conn.commit()
if not has_persona_id:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
'''
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
@@ -233,7 +202,7 @@ class SQLiteDatabase(BaseDatabase):
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -247,9 +216,9 @@ class SQLiteDatabase(BaseDatabase):
res = c.fetchone()
c.close()
return Conversation(*res)
return WebChatConversation(*res)
def new_conversation(self, user_id: str, cid: str):
def webchat_new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
@@ -259,7 +228,7 @@ class SQLiteDatabase(BaseDatabase):
''', (user_id, cid, history, updated_at, created_at)
)
def get_conversations(self, user_id: str) -> Tuple:
def get_webchat_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -267,7 +236,7 @@ class SQLiteDatabase(BaseDatabase):
c.execute(
'''
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
@@ -278,42 +247,24 @@ class SQLiteDatabase(BaseDatabase):
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新对话,并且同时更新时间'''
updated_at = int(time.time())
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
''', (history, updated_at, user_id, cid)
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
''', (title, user_id, cid)
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
''', (persona_id, user_id, cid)
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
)
def delete_conversation(self, user_id: str, cid: str):
def delete_webchat_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())
keywords = ",".join(vision.keywords)

View File

@@ -42,7 +42,5 @@ CREATE TABLE IF NOT EXISTS webchat_conversation(
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
updated_at INTEGER
);

View File

@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str]
path: T.Optional[str] # 用这个
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
@@ -306,7 +306,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: T.Union[str, int]
id: int
text: T.Optional[str] = ""
qq: T.Optional[int] = 0
time: T.Optional[int] = 0

View File

@@ -13,10 +13,12 @@ class MessageChain():
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''添加一条文本消息到消息链 `chain` 中。
@@ -75,6 +77,16 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''设置是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
@@ -101,6 +113,7 @@ class MessageEventResult(MessageChain):
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
@@ -126,7 +139,7 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
@@ -135,15 +148,5 @@ class MessageEventResult(MessageChain):
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult

View File

@@ -2,7 +2,6 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
@@ -12,7 +11,7 @@ from .respond.stage import RespondStage
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitStage", # 检查会话是否超过频率限制
"RateLimitCheckStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
@@ -23,7 +22,6 @@ STAGES_ORDER = [
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",

View File

@@ -21,10 +21,6 @@ class DifyRequestSubStage(Stage):
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
return
if provider.meta().type != "dify":
return

View File

@@ -2,7 +2,6 @@
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
@@ -11,7 +10,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
@@ -25,8 +24,6 @@ class LLMRequestSubStage(Stage):
if self.provider_wake_prefix.startswith(bwp):
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
@@ -49,19 +46,12 @@ class LLMRequestSubStage(Stage):
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = conversation_id
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
req.session_id = event.session_id
event.set_extra("provider_request", req)
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt and not req.image_urls:
if not req.prompt:
return
# 执行请求 LLM 前事件。
@@ -72,35 +62,18 @@ class LLMRequestSubStage(Stage):
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
try:
logger.debug(f"提供商请求 Payload: {req}")
logger.debug(f"提供商请求 Payload: {req.__dict__}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'err':
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
@@ -133,24 +106,4 @@ class LLMRequestSubStage(Stage):
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {
"role": "user",
"content": req.prompt
}
contexts.append(new_record)
contexts.append({
"role": "assistant",
"content": llm_response.completion_text
})
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.session_id,
history=contexts_to_save
)
return

View File

@@ -39,11 +39,8 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()

View File

@@ -37,11 +37,7 @@ class ProcessStage(Stage):
# Handler 的 LLM 请求
logger.debug(f"llm request -> {resp.prompt}")
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
_t = True
yield
if not _t:
yield
else:
yield
@@ -53,11 +49,6 @@ class ProcessStage(Stage):
if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
match provider.meta().type:
case "dify":
async for _ in self.dify_request_sub_stage.process(event):

View File

@@ -61,12 +61,11 @@ class RateLimitStage(Stage):
stall_duration = (next_window_time - now).total_seconds()
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
case RateLimitStrategy.STALL:
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
await asyncio.sleep(stall_duration)
case RateLimitStrategy.DISCARD.value:
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
case RateLimitStrategy.DISCARD:
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
return event.stop_event()
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))

View File

@@ -1,10 +1,7 @@
import random
import asyncio
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -12,19 +9,6 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间:{self.interval}")
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -32,16 +16,7 @@ class RespondStage(Stage):
return
if len(result.chain) > 0:
await event._pre_send()
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
# 分段回复
for comp in result.chain:
await event.send(MessageChain([comp]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
else:
await event.send(result)
await event._post_send()
await event.send(result)
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)

View File

@@ -1,13 +1,11 @@
import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record
from astrbot.core.message.components import Plain, Image, At, Reply
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -18,13 +16,7 @@ class ResultDecorateStage:
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.seg_prompt = ctx.astrbot_config['platform_settings']['segmented_reply']['seg_prompt']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
self.t2i = ctx.astrbot_config['t2i']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -39,67 +31,10 @@ class ResultDecorateStage:
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
for comp in result.chain:
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
if self.seg_prompt:
try:
llm_resp = await self.ctx.plugin_manager.context.get_using_provider().text_chat(
prompt=f"{self.seg_prompt}\n{comp.text}",
)
comp.text = llm_resp.completion_text
except BaseException as e:
traceback.print_exc()
logger.error("使用 LLM 分段回复失败: " + str(e))
new_chain.append(comp)
continue
split_response = re.findall(self.regex, comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
if seg:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
if self.use_tts and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + comp.text)
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
result.chain.insert(0, Plain(self.reply_prefix))
# 文本转图片
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
@@ -118,12 +53,8 @@ class ResultDecorateStage:
if url:
result.chain = [Image.fromURL(url)]
# at 回复
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
result.chain[1].text = "\n" + result.chain[1].text
result.chain.insert(0, At(qq=event.get_sender_id()))
# 引用回复
if self.reply_with_quote:
result.chain.insert(0, Reply(id=event.message_obj.message_id))

View File

@@ -2,11 +2,11 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
@register_stage
class WakingCheckStage(Stage):
@@ -21,9 +21,6 @@ class WakingCheckStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
)
async def process(
self, event: AstrMessageEvent
@@ -80,9 +77,7 @@ class WakingCheckStage(Stage):
# filter 需要满足 AND 的逻辑关系
passed = True
child_command_handler_md = None
permission_not_pass = False
if len(handler.event_filters) == 0:
# 不可能有这种情况, 也不允许有这种情况
continue
@@ -99,9 +94,6 @@ class WakingCheckStage(Stage):
else:
handler = child_command_handler_md # handler 覆盖
break
elif isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -119,13 +111,6 @@ class WakingCheckStage(Stage):
break
if passed:
if permission_not_pass:
if self.no_permission_reply:
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足"))
event.stop_event()
return
is_wake = True
event.is_wake = True

View File

@@ -18,11 +18,6 @@ class WhitelistCheckStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
# 白名单检查未启用
return
if len(self.whitelist) == 0:
# 白名单为空,不检查
return
if event.get_platform_name() == 'webchat':

View File

@@ -179,15 +179,6 @@ class AstrMessageEvent(abc.ABC):
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
'''设置消息事件的结果。
@@ -296,7 +287,6 @@ class AstrMessageEvent(abc.ABC):
def request_llm(
self,
prompt: str,
func_tool_manager = None,
session_id: str = None,
image_urls: List[str] = None,
contexts: List = None,
@@ -312,13 +302,11 @@ class AstrMessageEvent(abc.ABC):
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
'''
return ProviderRequest(
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
func_tool = func_tool_manager,
contexts = contexts,
system_prompt = system_prompt
)

View File

@@ -24,12 +24,11 @@ class PlatformManager():
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
try:
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
except BaseException:
logger.warning("当前 astrbot 已不维护 vchat 的接入,如有需要请 pip 安装 vchat 然后重启")
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "mispeaker":
from .sources.mispeaker.mispeaker_adapter import MiSpeakerPlatformAdapter # noqa: F401
async def initialize(self):

View File

@@ -3,7 +3,7 @@ import random
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -20,18 +20,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, (Image, Record)):
if isinstance(segment, Image):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_base64 = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
else:
bs64_data = file_to_base64(segment.file)
image_base64 = file_to_base64(image_file_path)
d['data'] = {
'file': bs64_data,
'file': image_base64,
}
ret.append(d)
return ret
@@ -40,5 +38,11 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
await self.bot.send(self.message_obj.raw_message, ret)
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
else:
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)

View File

@@ -102,7 +102,7 @@ class AiocqhttpAdapter(Platform):
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
@@ -122,10 +122,7 @@ class AiocqhttpAdapter(Platform):
def run(self) -> Awaitable[Any]:
if not self.host or not self.port:
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port将使用默认值http://0.0.0.0:6199")
self.host = "0.0.0.0"
self.port = 6199
return
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):

View File

@@ -2,14 +2,10 @@ import threading
import asyncio
import aiohttp
import quart
import base64
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api.message_components import Plain, Image, At
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.core.utils.io import download_image_by_url
class SimpleGewechatClient():
'''针对 Gewechat 的简单实现。
@@ -21,15 +17,9 @@ class SimpleGewechatClient():
self.base_url = base_url
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
@@ -37,19 +27,15 @@ class SimpleGewechatClient():
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.callback_url = None
self.server = quart.Quart(__name__)
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
@@ -66,89 +52,57 @@ class SimpleGewechatClient():
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
abm = AstrBotMessage()
d = data['Data']
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
msg_type = d['MsgType']
abm.message_id = str(d.get('MsgId'))
abm.session_id = from_user_name
abm.self_id = data['Wxid'] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
if '在群聊中@了你' in d.get('PushContent', ''):
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
.replace('在群聊中@了你', '') \
.replace('在群聊中发了一段语音', '') # 真实昵称
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
# 不同消息类型
match d['MsgType']:
match msg_type:
case 1:
# 文本消息
abm.message.append(Plain(content))
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
user_real_name = d['PushContent'].split(' : ')[0] # 真实昵称
user_real_name = user_real_name.replace('在群聊中@了你', '') # trick
abm.self_id = data['Wxid'] # 机器人的 wxid
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.session_id = from_user_name
abm.sender = MessageMember(user_id, user_real_name)
abm.message = [Plain(content)]
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
abm.message_id = str(d['MsgId'])
abm.raw_message = d
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid,
content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
with open(file_path, "wb") as f:
f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
logger.info(f"abm: {abm}")
return abm
case _:
logger.error(f"未实现的消息类型: {d['MsgType']}")
return
logger.info(f"abm: {abm}")
return abm
logger.error(f"未实现的消息类型: {msg_type}")
async def callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
@@ -156,45 +110,42 @@ class SimpleGewechatClient():
if data.get('testMsg', None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}")
abm = await self._convert(data)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={
"token": self.token,
"callbackUrl": self.callback_url
"callbackUrl": callback_url
}
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob['ret'] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
logger.info(f"将在 {callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
async def start_polling(self):
# 设置回调
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host='0.0.0.0',
host=self.host,
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder
)
@@ -235,8 +186,6 @@ class SimpleGewechatClient():
async def login(self):
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
if self.appid:
online = await self.check_online(self.appid)
@@ -296,7 +245,7 @@ class SimpleGewechatClient():
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{self.nickname}", appid)
sp.put(f"gewechat-appid-{nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
@@ -314,39 +263,4 @@ class SimpleGewechatClient():
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
logger.info(f"发送消息结果: {json_blob}")

View File

@@ -1,51 +0,0 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader():
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {
"Content-Type": "application/json",
"X-GEWE-TOKEN": token
}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}",
headers=self.headers,
json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {
"appId": appid,
"xml": xml,
"msgId": msg_id
}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
'''返回一个可下载的 URL'''
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {
"appId": appid,
"xml": xml,
"type": choice
}
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
json_blob = json.loads(data)
if 'fileUrl' in json_blob['data']:
return self.download_base_url + json_blob['data']['fileUrl']
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")

View File

@@ -1,24 +1,12 @@
import wave
import uuid
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
import random
import asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
@@ -46,57 +34,5 @@ class GewechatPlatformEvent(AstrMessageEvent):
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath('data/temp')
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
duration = await wav_to_tencent_silk(record_path, silk_path)
print(f"duration: {duration}, {silk_path}")
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# temp_directory = os.path.abspath('data/temp')
# record_path = os.path.abspath(record_path)
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
# with open(record_path, "rb") as f:
# record_path = f"data/temp/{uuid.uuid4()}.wav"
# with open(record_path, "wb") as f2:
# f2.write(f.read())
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
await self.client.post_voice(to_wxid, record_url, duration*1000)
await super().send(message)

View File

@@ -0,0 +1,137 @@
import os
import json
import asyncio
import aiohttp
import time
import traceback
from .miservice import MiAccount, MiNAService, MiIOService, miio_command, miio_command_help
from astrbot.core import logger
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At
class SimpleMiSpeakerClient():
'''
@author: Soulter
@references: https://github.com/yihong0618/xiaogpt/blob/main/xiaogpt/xiaogpt.py
'''
def __init__(self, config: dict):
self.username = config['username']
self.password = config['password']
self.did = config['did']
self.store = os.path.join("data", '.mi.token')
self.interval = float(config.get('interval', 1))
self.conv_query_cookies = {
'userId': '',
'deviceId': '',
'serviceToken': ''
}
self.MI_CONVERSATION_URL = "https://userprofile.mina.mi.com/device_profile/v2/conversation?source=dialogu&hardware={hardware}&timestamp={timestamp}&limit=1"
self.session = aiohttp.ClientSession()
self.activate_word = config.get('activate_word', '测试')
self.deactivate_word = config.get('deactivate_word', '停止')
self.entered = False
async def initialize(self):
account = MiAccount(self.session, self.username, self.password, self.store)
self.miio_service = MiIOService(account) # 小米设备服务
self.mina_service = MiNAService(account) # 小爱音箱服务
device = await self.get_mina_device()
self.deviceID = device['deviceID']
self.hardware = device['hardware']
with open(self.store, 'r') as f:
data = json.load(f)
self.userId = data['userId']
self.serviceToken = data['micoapi'][1]
self.conv_query_cookies['userId'] = self.userId
self.conv_query_cookies['deviceId'] = self.deviceID
self.conv_query_cookies['serviceToken'] = self.serviceToken
logger.info(f"MiSpeakerClient initialized. Conv cookies: {self.conv_query_cookies}. Hardware: {self.hardware}")
async def get_mina_device(self) -> dict:
devices = await self.mina_service.device_list()
for device in devices:
if device['miotDID'] == self.did:
logger.info(f"找到设备 {device['alias']}({device['name']}) 了!")
return device
async def get_conv(self) -> str:
# 时区请确保为北京时间
async with aiohttp.ClientSession() as session:
session.cookie_jar.update_cookies(self.conv_query_cookies)
query_ts = int(time.time())*1000
logger.debug(f"Querying conversation at {query_ts}")
async with session.get(self.MI_CONVERSATION_URL.format(hardware=self.hardware, timestamp=str(query_ts))) as resp:
json_blob = await resp.json()
if json_blob['code'] == 0:
data = json.loads(json_blob['data'])
records = data.get('records', None)
for record in records:
if record['time'] >= query_ts - self.interval*1000:
return record['query']
else:
logger.error(f"Failed to get conversation: {json_blob}")
return None
async def start_pooling(self):
while True:
await asyncio.sleep(self.interval)
try:
query = await self.get_conv()
if not query:
continue
# is wake
if query == self.activate_word:
self.entered = True
await self.stop_playing()
await self.send("我来啦!")
continue
elif query == self.deactivate_word:
self.entered = False
await self.stop_playing()
await self.send("再见,欢迎给个 Star。")
continue
if not self.entered:
continue
await self.send("")
abm = await self._convert(query)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
except BaseException as e:
traceback.print_exc()
logger.error(e)
async def _convert(self, query: str):
abm = AstrBotMessage()
abm.message = [Plain(query)]
abm.message_id = str(int(time.time()))
abm.message_str = query
abm.raw_message = query
abm.session_id = f"{self.hardware}_{self.did}_{self.username}"
abm.sender = MessageMember(self.username, "主人")
abm.self_id = f"{self.hardware}_{self.did}"
abm.type = MessageType.FRIEND_MESSAGE
return abm
async def send(self, message: str):
text = f'5 {message}'
await miio_command(self.miio_service, self.did, text, 'astrbot')
async def stop_playing(self):
text = f'3-2'
await miio_command(self.miio_service, self.did, text, 'astrbot')

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021-2022 Yonsm
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,5 @@
from .miaccount import MiAccount, MiTokenStore
from .minaservice import MiNAService
from .miioservice import MiIOService
from .miiocommand import miio_command, miio_command_help

View File

@@ -0,0 +1,135 @@
import base64
import hashlib
import json
import logging
import os
import random
import string
from urllib import parse
from aiohttp import ClientSession
from aiofiles import open as async_open
_LOGGER = logging.getLogger(__package__)
def get_random(length):
return ''.join(random.sample(string.ascii_letters + string.digits, length))
class MiTokenStore:
def __init__(self, token_path):
self.token_path = token_path
async def load_token(self):
if os.path.isfile(self.token_path):
try:
async with async_open(self.token_path) as f:
return json.loads(await f.read())
except Exception as e:
_LOGGER.exception("Exception on load token from %s: %s", self.token_path, e)
return None
async def save_token(self, token=None):
if token:
try:
async with async_open(self.token_path, 'w') as f:
await f.write(json.dumps(token, indent=2))
except Exception as e:
_LOGGER.exception("Exception on save token to %s: %s", self.token_path, e)
elif os.path.isfile(self.token_path):
os.remove(self.token_path)
class MiAccount:
def __init__(self, session: ClientSession, username, password, token_store='.mi.token'):
self.session = session
self.username = username
self.password = password
self.token_store = MiTokenStore(token_store) if isinstance(token_store, str) else token_store
self.token = None
async def login(self, sid):
if not self.token:
self.token = {'deviceId': get_random(16).upper()}
try:
resp = await self._serviceLogin(f'serviceLogin?sid={sid}&_json=true')
if resp['code'] != 0:
data = {
'_json': 'true',
'qs': resp['qs'],
'sid': resp['sid'],
'_sign': resp['_sign'],
'callback': resp['callback'],
'user': self.username,
'hash': hashlib.md5(self.password.encode()).hexdigest().upper()
}
resp = await self._serviceLogin('serviceLoginAuth2', data)
if resp['code'] != 0:
raise Exception(resp)
self.token['userId'] = resp['userId']
self.token['passToken'] = resp['passToken']
serviceToken = await self._securityTokenService(resp['location'], resp['nonce'], resp['ssecurity'])
self.token[sid] = (resp['ssecurity'], serviceToken)
if self.token_store:
await self.token_store.save_token(self.token)
return True
except Exception as e:
self.token = None
if self.token_store:
await self.token_store.save_token()
_LOGGER.exception("Exception on login %s: %s", self.username, e)
return False
async def _serviceLogin(self, uri, data=None):
headers = {'User-Agent': 'APP/com.xiaomi.mihome APPV/6.0.103 iosPassportSDK/3.9.0 iOS/14.4 miHSTS'}
cookies = {'sdkVersion': '3.9', 'deviceId': self.token['deviceId']}
if 'passToken' in self.token:
cookies['userId'] = self.token['userId']
cookies['passToken'] = self.token['passToken']
url = 'https://account.xiaomi.com/pass/' + uri
async with self.session.request('GET' if data is None else 'POST', url, data=data, cookies=cookies, headers=headers) as r:
raw = await r.read()
resp = json.loads(raw[11:])
_LOGGER.debug("%s: %s", uri, resp)
return resp
async def _securityTokenService(self, location, nonce, ssecurity):
nsec = 'nonce=' + str(nonce) + '&' + ssecurity
clientSign = base64.b64encode(hashlib.sha1(nsec.encode()).digest()).decode()
async with self.session.get(location + '&clientSign=' + parse.quote(clientSign)) as r:
serviceToken = r.cookies['serviceToken'].value
if not serviceToken:
raise Exception(await r.text())
return serviceToken
async def mi_request(self, sid, url, data, headers, relogin=True):
if self.token is None and self.token_store is not None:
self.token = await self.token_store.load_token()
if (self.token and sid in self.token) or await self.login(sid): # Ensure login
cookies = {'userId': self.token['userId'], 'serviceToken': self.token[sid][1]}
content = data(self.token, cookies) if callable(data) else data
method = 'GET' if data is None else 'POST'
_LOGGER.debug("%s %s", url, content)
async with self.session.request(method, url, data=content, cookies=cookies, headers=headers) as r:
status = r.status
if status == 200:
resp = await r.json(content_type=None)
code = resp['code']
if code == 0:
return resp
if 'auth' in resp.get('message', '').lower():
status = 401
else:
resp = await r.text()
if status == 401 and relogin:
_LOGGER.warn("Auth error on request %s %s, relogin...", url, resp)
self.token = None # Auth error, reset login
return await self.mi_request(sid, url, data, headers, False)
else:
resp = "Login failed"
raise Exception(f"Error {url}: {resp}")

View File

@@ -0,0 +1,104 @@
import json
from .miioservice import MiIOService
def twins_split(string, sep, default=None):
pos = string.find(sep)
return (string, default) if pos == -1 else (string[0:pos], string[pos+1:])
def string_to_value(string):
if string[0] in '"\'#':
return string[1:-1] if string[-1] in '"\'#' else string[1:]
elif string == 'null':
return None
elif string == 'false':
return False
elif string == 'true':
return True
elif string.isdigit():
return int(string)
try:
return float(string)
except:
return string
def miio_command_help(did=None, prefix='?'):
quote = '' if prefix == '?' else "'"
return f'\
Get Props: {prefix}<siid[-piid]>[,...]\n\
{prefix}1,1-2,1-3,1-4,2-1,2-2,3\n\
Set Props: {prefix}<siid[-piid]=[#]value>[,...]\n\
{prefix}2=60,2-1=#60,2-2=false,2-3="null",3=test\n\
Do Action: {prefix}<siid[-piid]> <arg1|[]> [...] \n\
{prefix}2 []\n\
{prefix}5 Hello\n\
{prefix}5-4 Hello 1\n\n\
Call MIoT: {prefix}<cmd=prop/get|/prop/set|action> <params>\n\
{prefix}action {quote}{{"did":"{did or "267090026"}","siid":5,"aiid":1,"in":["Hello"]}}{quote}\n\n\
Call MiIO: {prefix}/<uri> <data>\n\
{prefix}/home/device_list {quote}{{"getVirtualModel":false,"getHuamiDevices":1}}{quote}\n\n\
Devs List: {prefix}list [name=full|name_keyword] [getVirtualModel=false|true] [getHuamiDevices=0|1]\n\
{prefix}list Light true 0\n\n\
MIoT Spec: {prefix}spec [model_keyword|type_urn] [format=text|python|json]\n\
{prefix}spec\n\
{prefix}spec speaker\n\
{prefix}spec xiaomi.wifispeaker.lx04\n\
{prefix}spec urn:miot-spec-v2:device:speaker:0000A015:xiaomi-lx04:1\n\n\
MIoT Decode: {prefix}decode <ssecurity> <nonce> <data> [gzip]\n\
'
async def miio_command(service: MiIOService, did, text, prefix='?'):
cmd, arg = twins_split(text, ' ')
if cmd.startswith('/'):
return await service.miio_request(cmd, arg)
if cmd.startswith('prop') or cmd == 'action':
return await service.miot_request(cmd, json.loads(arg) if arg else None)
argv = arg.split(' ') if arg else []
argc = len(argv)
if cmd == 'list':
return await service.device_list(argc > 0 and argv[0], argc > 1 and string_to_value(argv[1]), argc > 2 and argv[2])
if cmd == 'spec':
return await service.miot_spec(argc > 0 and argv[0], argc > 1 and argv[1])
if cmd == 'decode':
return MiIOService.miot_decode(argv[0], argv[1], argv[2], argc > 3 and argv[3] == 'gzip')
if not did or not cmd or cmd == '?' or cmd == '' or cmd == 'help' or cmd == '-h' or cmd == '--help':
return miio_command_help(did, prefix)
if not did.isdigit():
devices = await service.device_list(did)
if not devices:
return "Device not found: " + did
did = devices[0]['did']
props = []
setp = True
miot = True
for item in cmd.split(','):
key, value = twins_split(item, '=')
siid, iid = twins_split(key, '-', '1')
if siid.isdigit() and iid.isdigit():
prop = [int(siid), int(iid)]
else:
prop = [key]
miot = False
if value is None:
setp = False
elif setp:
prop.append(string_to_value(value))
props.append(prop)
if miot and argc > 0:
args = [] if arg == '[]' else [string_to_value(a) for a in argv]
return await service.miot_action(did, props[0], args)
do_props = ((service.home_get_props, service.miot_get_props), (service.home_set_props, service.miot_set_props))[setp][miot]
return await do_props(did, props if miot or setp else [p[0] for p in props])

View File

@@ -0,0 +1,197 @@
import os
import time
import base64
import hashlib
import hmac
import json
# REGIONS = ['cn', 'de', 'i2', 'ru', 'sg', 'us']
class MiIOService:
def __init__(self, account=None, region=None):
self.account = account
self.server = 'https://' + ('' if region is None or region == 'cn' else region + '.') + 'api.io.mi.com/app'
async def miio_request(self, uri, data):
def prepare_data(token, cookies):
cookies['PassportDeviceId'] = token['deviceId']
return MiIOService.sign_data(uri, data, token['xiaomiio'][0])
headers = {'User-Agent': 'iOS-14.4-6.0.103-iPhone12,3--D7744744F7AF32F0544445285880DD63E47D9BE9-8816080-84A3F44E137B71AE-iPhone', 'x-xiaomi-protocal-flag-cli': 'PROTOCAL-HTTP2'}
resp = await self.account.mi_request('xiaomiio', self.server + uri, prepare_data, headers)
if 'result' not in resp:
raise Exception(f"Error {uri}: {resp}")
return resp['result']
async def home_request(self, did, method, params):
return await self.miio_request('/home/rpc/' + did, {'id': 1, 'method': method, "accessKey": "IOS00026747c5acafc2", 'params': params})
async def home_get_props(self, did, props):
return await self.home_request(did, 'get_prop', props)
async def home_set_props(self, did, props):
return [await self.home_set_prop(did, i[0], i[1]) for i in props]
async def home_get_prop(self, did, prop):
return (await self.home_get_props(did, [prop]))[0]
async def home_set_prop(self, did, prop, value):
result = (await self.home_request(did, 'set_' + prop, value if isinstance(value, list) else [value]))[0]
return 0 if result == 'ok' else result
async def miot_request(self, cmd, params):
return await self.miio_request('/miotspec/' + cmd, {'params': params})
async def miot_get_props(self, did, iids):
params = [{'did': did, 'siid': i[0], 'piid': i[1]} for i in iids]
result = await self.miot_request('prop/get', params)
return [it.get('value') if it.get('code') == 0 else None for it in result]
async def miot_set_props(self, did, props):
params = [{'did': did, 'siid': i[0], 'piid': i[1], 'value': i[2]} for i in props]
result = await self.miot_request('prop/set', params)
return [it.get('code', -1) for it in result]
async def miot_get_prop(self, did, iid):
return (await self.miot_get_props(did, [iid]))[0]
async def miot_set_prop(self, did, iid, value):
return (await self.miot_set_props(did, [(iid[0], iid[1], value)]))[0]
async def miot_action(self, did, iid, args=[]):
result = await self.miot_request('action', {'did': did, 'siid': iid[0], 'aiid': iid[1], 'in': args})
return result.get('code', -1)
async def device_list(self, name=None, getVirtualModel=False, getHuamiDevices=0):
result = await self.miio_request('/home/device_list', {'getVirtualModel': bool(getVirtualModel), 'getHuamiDevices': int(getHuamiDevices)})
result = result['list']
return result if name == 'full' else [{'name': i['name'], 'model': i['model'], 'did': i['did'], 'token': i['token']} for i in result if not name or name in i['name']]
async def miot_spec(self, type=None, format=None):
if not type or not type.startswith('urn'):
def get_spec(all):
if not type:
return all
ret = {}
for m, t in all.items():
if type == m:
return {m: t}
elif type in m:
ret[m] = t
return ret
import tempfile
path = os.path.join(tempfile.gettempdir(), 'miservice_miot_specs.json')
try:
with open(path) as f:
result = get_spec(json.load(f))
except:
result = None
if not result:
async with self.account.session.get('http://miot-spec.org/miot-spec-v2/instances?status=all') as r:
all = {i['model']: i['type'] for i in (await r.json())['instances']}
with open(path, 'w') as f:
json.dump(all, f)
result = get_spec(all)
if len(result) != 1:
return result
type = list(result.values())[0]
url = 'http://miot-spec.org/miot-spec-v2/instance?type=' + type
async with self.account.session.get(url) as r:
result = await r.json()
def parse_desc(node):
desc = node['description']
# pos = desc.find(' ')
# if pos != -1:
# return (desc[:pos], ' # ' + desc[pos + 2:])
name = ''
for i in range(len(desc)):
d = desc[i]
if d in '-—{「[【(<《':
return (name, ' # ' + desc[i:])
name += '_' if d == ' ' else d
return (name, '')
def make_line(siid, iid, desc, comment, readable=False):
value = f"({siid}, {iid})" if format == 'python' else iid
return f" {'' if readable else '_'}{desc} = {value}{comment}\n"
if format != 'json':
STR_HEAD, STR_SRV, STR_VALUE = ('from enum import Enum\n\n', '\nclass {}(tuple, Enum):\n', '\nclass {}(int, Enum):\n') if format == 'python' else ('', '{} = {}\n', '{}\n')
text = '# Generated by https://github.com/Yonsm/MiService\n# ' + url + '\n\n' + STR_HEAD
svcs = []
vals = []
for s in result['services']:
siid = s['iid']
svc = s['description'].replace(' ', '_')
svcs.append(svc)
text += STR_SRV.format(svc, siid)
for p in s.get('properties', []):
name, comment = parse_desc(p)
access = p['access']
comment += ''.join([' # ' + k for k, v in [(p['format'], 'string'), (''.join([a[0] for a in access]), 'r')] if k and k != v])
text += make_line(siid, p['iid'], name, comment, 'read' in access)
if 'value-range' in p:
valuer = p['value-range']
length = min(3, len(valuer))
values = {['MIN', 'MAX', 'STEP'][i]: valuer[i] for i in range(length) if i != 2 or valuer[i] != 1}
elif 'value-list' in p:
values = {i['description'].replace(' ', '_') if i['description'] else str(i['value']): i['value'] for i in p['value-list']}
else:
continue
vals.append((svc + '_' + name, values))
if 'actions' in s:
text += '\n'
for a in s['actions']:
name, comment = parse_desc(a)
comment += ''.join([f" # {io}={a[io]}" for io in ['in', 'out'] if a[io]])
text += make_line(siid, a['iid'], name, comment)
text += '\n'
for name, values in vals:
text += STR_VALUE.format(name)
for k, v in values.items():
text += f" {'_' + k if k.isdigit() else k} = {v}\n"
text += '\n'
if format == 'python':
text += '\nALL_SVCS = (' + ', '.join(svcs) + ')\n'
result = text
return result
@staticmethod
def miot_decode(ssecurity, nonce, data, gzip=False):
from Crypto.Cipher import ARC4
r = ARC4.new(base64.b64decode(MiIOService.sign_nonce(ssecurity, nonce)))
r.encrypt(bytes(1024))
decrypted = r.encrypt(base64.b64decode(data))
if gzip:
try:
from io import BytesIO
from gzip import GzipFile
compressed = BytesIO()
compressed.write(decrypted)
compressed.seek(0)
decrypted = GzipFile(fileobj=compressed, mode='rb').read()
except:
pass
return json.loads(decrypted.decode())
@staticmethod
def sign_nonce(ssecurity, nonce):
m = hashlib.sha256()
m.update(base64.b64decode(ssecurity))
m.update(base64.b64decode(nonce))
return base64.b64encode(m.digest()).decode()
@staticmethod
def sign_data(uri, data, ssecurity):
if not isinstance(data, str):
data = json.dumps(data)
nonce = base64.b64encode(os.urandom(8) + int(time.time() / 60).to_bytes(4, 'big')).decode()
snonce = MiIOService.sign_nonce(ssecurity, nonce)
msg = '&'.join([uri, snonce, nonce, 'data=' + data])
sign = hmac.new(key=base64.b64decode(snonce), msg=msg.encode(), digestmod=hashlib.sha256).digest()
return {'_nonce': nonce, 'data': data, 'signature': base64.b64encode(sign).decode()}

View File

@@ -0,0 +1,50 @@
import json
from .miaccount import MiAccount, get_random
import logging
_LOGGER = logging.getLogger(__package__)
class MiNAService:
def __init__(self, account: MiAccount):
self.account = account
async def mina_request(self, uri, data=None):
requestId = 'app_ios_' + get_random(30)
if data is not None:
data['requestId'] = requestId
else:
uri += '&requestId=' + requestId
headers = {'User-Agent': 'MiHome/6.0.103 (com.xiaomi.mihome; build:6.0.103.1; iOS 14.4.0) Alamofire/6.0.103 MICO/iOSApp/appStore/6.0.103'}
return await self.account.mi_request('micoapi', 'https://api2.mina.mi.com' + uri, data, headers)
async def device_list(self, master=0):
result = await self.mina_request('/admin/v2/device_list?master=' + str(master))
return result.get('data') if result else None
async def ubus_request(self, deviceId, method, path, message):
message = json.dumps(message)
result = await self.mina_request('/remote/ubus', {'deviceId': deviceId, 'message': message, 'method': method, 'path': path})
return result and result.get('code') == 0
async def text_to_speech(self, deviceId, text):
return await self.ubus_request(deviceId, 'text_to_speech', 'mibrain', {'text': text})
async def player_set_volume(self, deviceId, volume):
return await self.ubus_request(deviceId, 'player_set_volume', 'mediaplayer', {'volume': volume, 'media': 'app_ios'})
async def send_message(self, devices, devno, message, volume=None): # -1/0/1...
result = False
for i in range(0, len(devices)):
if devno == -1 or devno != i + 1 or devices[i]['capabilities'].get('yunduantts'):
_LOGGER.debug("Send to devno=%d index=%d: %s", devno, i, message or volume)
deviceId = devices[i]['deviceID']
result = True if volume is None else await self.player_set_volume(deviceId, volume)
if result and message:
result = await self.text_to_speech(deviceId, message)
if not result:
_LOGGER.error("Send failed: %s", message or volume)
if devno != -1 or not result:
break
return result

View File

@@ -0,0 +1,63 @@
import logging
import time
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from astrbot.core.message.components import BaseMessageComponent
from .client import SimpleMiSpeakerClient
from .mispeaker_event import MiSpeakerPlatformEvent
from astrbot.core import logger
@register_platform_adapter("mispeaker", "小爱音箱")
class MiSpeakerPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
pass
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"mispeaker",
"小爱音箱",
)
async def handle_msg(self, message: AstrBotMessage):
message_event = MiSpeakerPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)
def run(self):
self.client = SimpleMiSpeakerClient(
self.config
)
async def on_event_received(abm: AstrBotMessage):
logger.info(f"on_event_received: {abm}")
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
return self._run()
async def _run(self):
await self.client.initialize()
await self.client.start_pooling()

View File

@@ -0,0 +1,30 @@
import random
import asyncio
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from .client import SimpleMiSpeakerClient
class MiSpeakerPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleMiSpeakerClient
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(message: MessageChain, user_name: str):
pass
async def send(self, message: MessageChain):
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.send(comp.text)
await super().send(message)

View File

@@ -5,7 +5,7 @@ import botpy.types.message
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Reply
from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
@@ -14,33 +14,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
'''QQ 官方 API 仅支持回复一次'''
source = self.message_obj.raw_message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
ref = None
for i in self.send_buffer.chain:
if isinstance(i, Reply):
try:
ref = self.message_obj.raw_message.message_reference
ref = botpy.types.message.Reference(
message_id=ref.message_id,
ignore_get_message_error=False
)
except BaseException as _:
pass
break
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
payload = {
'content': plain_text,
@@ -49,37 +28,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
match type(source):
case botpy.message.GroupMessage:
if ref:
payload['message_reference'] = ref
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if ref:
payload['message_reference'] = ref
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if ref:
payload['message_reference'] = ref
if image_path:
payload['file_image'] = image_path
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
case botpy.message.DirectMessage:
if ref:
payload['message_reference'] = ref
if image_path:
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(self.send_buffer)
self.send_buffer = None
await super().send(message)
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
payload = {
@@ -111,7 +80,4 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
return plain_text, image_base64, image_file_path

View File

@@ -32,10 +32,6 @@ class WebChatMessageEvent(AstrMessageEvent):
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
web_chat_back_queue.put_nowait(None)
await super().send(message)

View File

@@ -2,8 +2,6 @@ import enum
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from astrbot.core.db.po import Conversation
class ProviderType(enum.Enum):
@@ -39,15 +37,10 @@ class ProviderRequest():
'''上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
'''
system_prompt: str = ""
'''系统提示词'''
conversation: Conversation = None
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
def __str__(self):
return self.__repr__()
@dataclass
class LLMResponse:
@@ -58,7 +51,4 @@ class LLMResponse:
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
'''工具调用名称'''

View File

@@ -108,21 +108,14 @@ class FuncCall:
for f in self.func_list:
if not f.active:
continue
func_declaration = {
"name": f.name,
"description": f.description
}
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
if params.get("properties", {}):
func_declaration["parameters"] = params
tools.append(func_declaration)
if tools:
declarations["function_declarations"] = tools
tools.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
declarations["function_declarations"] = tools
return declarations

View File

@@ -1,7 +1,6 @@
import traceback
import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .provider import Provider, STTProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
@@ -14,11 +13,8 @@ class ProviderManager():
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
self.personas: List[Personality] = []
self.selected_default_persona = None
@@ -30,7 +26,7 @@ class ProviderManager():
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
begin_dialogs = []
continue
user_turn = True
for dialog in begin_dialogs:
bd_processed.append({
@@ -42,9 +38,9 @@ class ProviderManager():
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
mood_imitation_dialogs = []
continue
user_turn = True
for dialog in mood_imitation_dialogs:
for dialog in begin_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
@@ -63,33 +59,16 @@ class ProviderManager():
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed=""
)
self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
@@ -99,16 +78,12 @@ class ProviderManager():
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
changed = False
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}")
provider_cfg['id'] = new_id
changed = True
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
self.loaded_ids[provider_cfg['id']] = True
try:
@@ -128,37 +103,25 @@ class ProviderManager():
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
if changed:
try:
config.save_config()
except Exception as e:
logger.warning(f"保存配置文件失败:{e}")
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
async def initialize(self):
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
@@ -175,18 +138,6 @@ class ProviderManager():
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
@@ -216,18 +167,11 @@ class ProviderManager():
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst:
if self.provider_stt_settings.get("enable"):
if not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
def get_insts(self):
return self.provider_insts

View File

@@ -8,8 +8,6 @@ from typing import TypedDict
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
@@ -17,8 +15,8 @@ class Personality(TypedDict):
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
_begin_dialogs_processed: List[dict]
_mood_imitation_dialogs_processed: str
@dataclass
@@ -26,13 +24,40 @@ class ProviderMeta():
id: str
model: str
type: str
class AbstractProvider(abc.ABC):
def __init__(self, provider_config: dict) -> None:
super().__init__()
class Provider(abc.ABC):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
self.model_name = ""
'''当前使用的模型名称'''
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
'''维护了当前的使用的 persona即人格。可能为 None'''
self.db_helper = db_helper
'''用于持久化的数据库操作对象。'''
if persistant_history:
# 读取历史记录
try:
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
@@ -42,31 +67,6 @@ class AbstractProvider(abc.ABC):
'''获得当前使用的模型名称'''
return self.model_name
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class Provider(AbstractProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
'''维护了当前的使用的 persona即人格。可能为 None'''
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError()
@@ -84,6 +84,22 @@ class Provider(AbstractProvider):
'''获得支持的模型列表'''
raise NotImplementedError()
@abc.abstractmethod
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
'''获取人类可读的上下文
page 从 1 开始
Example:
["User: 你好", "Assistant: 你好!"]
Return:
contexts: List[str]: 上下文列表
total_pages: int: 总页数
'''
raise NotImplementedError()
@abc.abstractmethod
async def text_chat(self,
prompt: str,
@@ -97,40 +113,37 @@ class Provider(AbstractProvider):
Args:
prompt: 提示词
session_id: 会话 ID(此属性已经被废弃)
session_id: 会话 ID
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
kwargs: 其他参数
Notes:
- 如果传入了 contexts将会提前加上上下文。否则使用 session_memory 中的上下文。
- 可以选择性地传入 session_id如果传入了 session_id将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls将会在对话时附上图片。如果模型不支持图片输入将会抛出错误。
- 如果传入了 tools将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling将会抛出错误。
'''
raise NotImplementedError()
@abc.abstractmethod
async def forget(self, session_id: str) -> bool:
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
async def pop_record(self, context: List):
'''
弹出 context 第一条非系统提示词对话记录
'''
poped = 0
indexs_to_pop = []
for idx, record in enumerate(context):
if record["role"] == "system":
continue
else:
indexs_to_pop.append(idx)
poped += 1
if poped == 2:
break
for idx in reversed(indexs_to_pop):
context.pop(idx)
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider(AbstractProvider):
class STTProvider():
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@@ -138,15 +151,19 @@ class STTProvider(AbstractProvider):
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_audio(self, text: str) -> str:
'''获取文本的音频,返回音频文件路径'''
raise NotImplementedError()
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获取当前使用的模型'''
return self.provider_config.get("model", "")
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)

View File

@@ -73,10 +73,60 @@ class ProviderGoogleGenAI(Provider):
api_base=provider_config.get("api_base", None)
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
return await self.client.models_list()
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
@@ -131,9 +181,6 @@ class ProviderGoogleGenAI(Provider):
)
logger.debug(f"result: {result}")
if "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]['content']['parts']
llm_response = LLMResponse("assistant")
for candidate in candidates:
@@ -143,47 +190,49 @@ class ProviderGoogleGenAI(Provider):
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate['functionCall']['args'])
llm_response.tools_call_name.append(candidate['functionCall']['name'])
llm_response.completion_text = llm_response.completion_text.strip()
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str = None,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=[],
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**model_config
**self.provider_config.get("model_config", {})
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 10
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
await self.pop_record(context_query)
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
@@ -191,19 +240,31 @@ class ProviderGoogleGenAI(Provider):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
elif "Function calling is not enabled" in str(e):
logger.info(f"{self.get_model()} 不支持函数调用工具调用,已经自动去除")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
raise e
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
def get_current_key(self) -> str:
return self.client.api_key

View File

@@ -57,13 +57,20 @@ class LLMTunerModelLoader(Provider):
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = [],
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
new_record = {"role": "user", "content": prompt}
query_context = [*contexts, new_record]
if not contexts:
query_context = [
*self.session_memory[session_id],
new_record,
]
system_prompt = self.curr_personality["prompt"]
else:
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
@@ -89,8 +96,33 @@ class LLMTunerModelLoader(Provider):
responses = await self.model.achat(**conf)
llm_response = LLMResponse("assistant", responses[-1].response_text)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id):
self.session_memory[session_id] = []
return True
async def get_current_key(self):
return "none"
@@ -99,4 +131,28 @@ class LLMTunerModelLoader(Provider):
pass
async def get_models(self):
return [self.get_model()]
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages

View File

@@ -1,8 +1,8 @@
import traceback
import base64
import json
import os
from openai import AsyncOpenAI, AsyncAzureOpenAI, NOT_GIVEN
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import NotFoundError
from astrbot.core.utils.io import download_image_by_url
@@ -29,25 +29,37 @@ class ProviderOpenAIOfficial(Provider):
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
# 适配 azure openai #332
if "api_version" in provider_config:
# 使用 azure api
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
else:
# 使用 openai api
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
try:
@@ -60,6 +72,32 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_openai_style()
@@ -70,9 +108,9 @@ class ProviderOpenAIOfficial(Provider):
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion}")
logger.debug(f"completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
@@ -81,8 +119,7 @@ class ProviderOpenAIOfficial(Provider):
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
return LLMResponse("assistant", completion_text, raw_completion=completion)
return LLMResponse("assistant", completion_text)
elif choice.message.tool_calls:
# tools call (function calling)
args_ls = []
@@ -93,49 +130,49 @@ class ProviderOpenAIOfficial(Provider):
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls, raw_completion=completion)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
else:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception("Internal Error")
async def text_chat(
self,
prompt: str,
session_id: str=None,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=[],
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**model_config
**self.provider_config.get("model_config", {})
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
# 重试 10 次
retry_cnt = 10
while retry_cnt > 0:
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
logger.warning(f"请求失败:{e}上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
await self.pop_record(session_id)
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
@@ -143,65 +180,33 @@ class ProviderOpenAIOfficial(Provider):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads['messages'] = new_contexts
llm_response = await self._query(payloads, func_tool)
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
elif 'does not support Function Calling' in str(e) \
or 'does not support tools' in str(e) \
or 'Function call is not supported' in str(e) \
or 'Function calling is not enabled' in str(e) \
or 'Tool calling is not supported' in str(e): # siliconcloud
logger.info(f"{self.get_model()} 不支持函数调用工具调用,已经自动去除")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if 'tool' in str(e).lower() and 'support' in str(e).lower():
logger.error(f"疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
if 'Connection error.' in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}")
raise e
return llm_response
async def _remove_image_from_context(self, contexts: List):
'''
从上下文中删除所有带有 image 的记录
'''
new_contexts = []
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
flag = False
for context in contexts:
if flag:
flag = False # 删除 image 后下一条LLM 响应)也要删除
continue
if isinstance(context['content'], list):
flag = True
# continue
new_content = []
for item in context['content']:
if isinstance(item, dict) and 'image_url' in item:
continue
new_content.append(item)
if not new_content:
# 用户只发了图片
new_content = [{"type": "text", "text": "[图片]"}]
context['content'] = new_content
new_contexts.append(context)
return new_contexts
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
def get_current_key(self) -> str:
return self.client.api_key

View File

@@ -1,40 +0,0 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
class ProviderOpenAITTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("openai-tts-voice", "alloy")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def get_audio(self, text: str) -> str:
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format='wav',
input=text
) as response:
with open(path, 'wb') as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path

View File

@@ -22,21 +22,24 @@ class ProviderZhipu(ProviderOpenAIOfficial):
async def text_chat(
self,
prompt: str,
session_id: str = None,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=[],
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model()
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
@@ -59,6 +62,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
}
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):

View File

@@ -1,7 +1,3 @@
'''
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
'''
from typing import Union
import os
import json

View File

@@ -10,13 +10,12 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata, star_map
from .star import star_registry, StarMetadata
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class Context:
'''
@@ -45,7 +44,6 @@ class Context:
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
@@ -54,129 +52,21 @@ class Context:
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
self.conversation_manager = conversation_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
for star in star_registry:
if star.name == star_name:
return star
def get_all_stars(self) -> List[StarMetadata]:
'''获取当前载入的所有插件 Metadata 的列表'''
return star_registry
def get_llm_tool_manager(self) -> FuncCall:
'''获取 LLM Tool Manager其用于管理注册的所有的 Function-calling tools'''
'''
获取 LLM Tool Manager
'''
return self.provider_manager.llm_tools
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
Returns:
如果没找到,会返回 False
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def deactivate_llm_tool(self, name: str) -> bool:
'''停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''获取 AstrBot 的配置。'''
return self._config
def get_db(self) -> BaseDatabase:
'''获取 AstrBot 数据库。'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
'''
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
'''
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用function-calling / tools-use添加工具。
@@ -204,7 +94,41 @@ class Context:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
Returns:
如果没找到,会返回 False
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def deactivate_llm_tool(self, name: str) -> bool:
'''停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
@@ -238,6 +162,77 @@ class Context:
))
star_handlers_registry.append(md)
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider(Chat_Completion 类型)。
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider(Chat_Completion 类型)。
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
def register_task(self, task: Awaitable, desc: str):
'''
注册一个异步任务。

View File

@@ -46,11 +46,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
if not event.is_wake_up():
return False
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
message_str = event.get_message_str().strip()
# 分割为列表(每个参数之间可能会有多个空格)
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:

View File

@@ -40,24 +40,17 @@ class CommandGroupFilter(HandlerFilter):
if not event.is_wake_up():
return False, None
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
message_str = event.get_message_str().strip()
ls = re.split(r"\s+", message_str)
if ls[0] != self.group_name:
return False, None
# 改写 message_str
ls = ls[1:]
# event.message_str = " ".join(ls)
# event.message_str = event.message_str.strip()
parsing_command = " ".join(ls)
parsing_command = parsing_command.strip()
event.set_extra("parsing_command", parsing_command)
event.message_str = " ".join(ls)
event.message_str = event.message_str.strip()
if parsing_command == "":
if event.message_str == "":
# 当前还是指令组
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)

View File

@@ -19,8 +19,7 @@ class PermissionTypeFilter(HandlerFilter):
'''
if self.permission_type == PermissionType.ADMIN:
if not event.is_admin():
# event.stop_event()
# raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
return False
event.stop_event()
raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限执行此操作。")
return True

View File

@@ -7,7 +7,6 @@ from .star_handler import (
register_regex,
register_permission_type,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
register_on_decorating_result,
register_after_message_sent
@@ -22,7 +21,6 @@ __all__ = [
'register_regex',
'register_permission_type',
'register_on_llm_request',
'register_on_llm_response',
'register_llm_tool',
'register_on_decorating_result',
'register_after_message_sent'

View File

@@ -17,7 +17,7 @@ def get_handler_full_name(awaitable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False, **kwargs) -> StarHandlerMetadata:
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
@@ -30,17 +30,14 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
handler_name=handler.__name__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[],
event_filters=[]
)
if handler.__doc__:
md.desc = handler.__doc__.strip()
if not dont_add:
star_handlers_registry.append(md)
return md
def register_command(command_name: str = None, *args):
'''注册一个 Command.
'''
'''注册一个 Command'''
new_command = None
add_to_event_filters = False
@@ -65,8 +62,7 @@ def register_command(command_name: str = None, *args):
return decorator
def register_command_group(command_group_name: str = None, *args):
'''注册一个 CommandGroup
'''
'''注册一个 CommandGroup'''
new_group = None
add_to_event_filters = False
@@ -143,8 +139,6 @@ def register_on_llm_request():
Examples:
```py
from astrbot.api.provider import ProviderRequest
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
@@ -158,27 +152,6 @@ def register_on_llm_request():
return decorator
def register_on_llm_response():
'''当有 LLM 请求后的事件
Examples:
```py
from astrbot.api.provider import LLMResponse
@on_llm_response()
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
...
```
请务必接收两个参数event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
return awaitable
return decorator
def register_llm_tool(name: str = None):
'''为函数调用function-calling / tools-use添加工具。

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
@@ -12,7 +11,7 @@ star_map: Dict[str, StarMetadata] = {}
@dataclass
class StarMetadata:
'''
插件的元数据。
Star 的元数据。
'''
name: str
author: str # 插件作者
@@ -21,24 +20,21 @@ class StarMetadata:
repo: str = None # 插件仓库地址
star_cls_type: type = None
'''插件的类对象的类型'''
'''Star 的类对象的类型'''
module_path: str = None
'''插件的模块路径'''
'''Star 的模块路径'''
star_cls: object = None
'''插件的类对象'''
'''Star 的类对象'''
module: ModuleType = None
'''插件的模块对象'''
'''Star 的模块对象'''
root_dir_name: str = None
'''插件的目录名'''
'''Star 的根目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留插件'''
'''是否是 AstrBot 的保留 Star'''
activated: bool = True
'''是否被激活'''
config: AstrBotConfig = None
'''插件配置'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"

View File

@@ -47,7 +47,6 @@ class EventType(enum.Enum):
'''
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后
OnDecoratingResultEvent = enum.auto() # 发送消息前
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后

View File

@@ -2,14 +2,12 @@ import inspect
import functools
import os
import sys
import json
import traceback
import yaml
import logging
from types import ModuleType
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import DEFAULT_VALUE_MAP
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
@@ -19,8 +17,6 @@ from .star import star_registry, star_map
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
from .filter.permission import PermissionTypeFilter, PermissionType
class PluginManager:
def __init__(
self,
@@ -30,20 +26,13 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self
self.context._star_manager = self # 就这样吧,不想改了
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
'''存储插件的路径。即 data/plugins'''
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
'''存储插件配置的路径。data/config'''
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
'''保留插件的路径。在 packages 目录下'''
self.conf_schema_fname = "_conf_schema.json"
'''插件配置 Schema 文件名'''
def _get_classes(self, arg: ModuleType):
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers:
@@ -139,7 +128,7 @@ class PluginManager:
return metadata
async def reload(self):
'''扫描并加载所有的插件'''
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
@@ -161,15 +150,13 @@ class PluginManager:
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
# 导入插件模块,并尝试实例化插件类
# 导入 Star 模块,并尝试实例化 Star 类
for plugin_module in plugin_modules:
try:
module_str = plugin_module['module']
# module_path = plugin_module['module_path']
root_dir_name = plugin_module['pname'] # 插件的目录名
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
root_dir_name = plugin_module['pname']
reserved = plugin_module.get('reserved', False)
logger.info(f"正在载入插件 {root_dir_name} ...")
@@ -186,43 +173,21 @@ class PluginManager:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
if os.path.exists(plugin_schema_path):
# 加载插件配置
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
schema=json.loads(f.read())
)
if path in star_map:
# 通过装饰器的方式注册插件
metadata = star_map[path]
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(context=self.context)
else:
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 绑定 handler
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
for handler in related_handlers:
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, metadata.star_cls)
# 绑定 llm_tool handler
# llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
@@ -234,20 +199,16 @@ class PluginManager:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
classes = self._get_classes(module)
try:
obj = getattr(module, classes[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
if plugin_config:
try:
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
else:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -257,30 +218,10 @@ class PluginManager:
star_registry.append(metadata)
logger.debug(f"插件 {root_dir_name} 载入成功。")
# 禁用/启用插件
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 检查并且植入自定义的权限过滤器alter_cmd
for handler in star_handlers_registry.get_handlers_by_module_name(metadata.module_path):
if metadata.name in alter_cmd and handler.handler_name in alter_cmd[metadata.name]:
# 注入权限过滤器
cmd_type = alter_cmd[metadata.name][handler.handler_name].get("permission", "member")
found_permission_filter = False
for filter_ in handler.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
filter_.permission_type = PermissionType.ADMIN
else:
filter_.permission_type = PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
handler.event_filters.append(PermissionTypeFilter(PermissionType.ADMIN if cmd_type == "admin" else PermissionType.MEMBER))
logger.debug(f"插入权限过滤器 {cmd_type}{metadata.name}{handler.handler_name} 方法。")
# 执行 initialize() 方法
# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
@@ -351,14 +292,13 @@ class PluginManager:
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
@@ -383,9 +323,8 @@ class PluginManager:
plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
desti_dir = os.path.join(self.plugin_store_path, dir_name)
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
self.updator.unzip_file(zip_file_path, desti_dir)
# remove the zip
@@ -393,4 +332,6 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
await self.reload()
self._check_plugin_dept_update()

View File

@@ -6,8 +6,6 @@ import time
import aiohttp
import base64
import zipfile
import uuid
from typing import Union
from PIL import Image
@@ -43,21 +41,21 @@ def port_checker(port: int, host: str = "localhost"):
return False
def save_temp_img(img: Union[Image.Image, str]) -> str:
def save_temp_img(img: Image) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过 12 小时的
# 获得文件创建时间,清除超过1小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600*12:
if time.time() - ctime > 3600:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
timestamp = int(time.time())
p = f"data/temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
@@ -109,7 +107,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
'''
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, timeout=1800) as resp:
async with session.get(url, timeout=120) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get('content-length', 0))

View File

@@ -25,10 +25,6 @@ class ParameterValidationMixin:
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:

View File

@@ -20,23 +20,18 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
return output_path
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
'''返回 duration'''
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
import pysilk
with wave.open(wav_path, 'rb') as wav:
wav_data = wav.readframes(wav.getnframes())
wav_data = BytesIO(wav_data)
output_io = BytesIO()
pysilk.encode(wav_data, output_io, 24000, 24000)
pysilk.encode(wav_data, output_io, 24000)
output_io.seek(0)
# 在首字节添加 \x02,去除结尾的\xff\xff
# 在首字节添加 \x02
silk_data = output_io.read()
silk_data_with_prefix = b'\x02' + silk_data[:-2]
silk_data_with_prefix = b'\x02' + silk_data
# return BytesIO(silk_data_with_prefix)
with open(output_path, "wb") as f:
f.write(silk_data_with_prefix)
return 0
return BytesIO(silk_data_with_prefix)

View File

@@ -39,6 +39,7 @@ class RepoZipUpdator():
else:
ret = self.github_api_release_parser(result)
except BaseException:
logger.error("解析版本信息失败")
raise Exception("解析版本信息失败")
return ret

View File

@@ -121,7 +121,7 @@ class ChatRoute(Route):
}))
# 持久化
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
@@ -136,7 +136,7 @@ class ChatRoute(Route):
if audio_url:
new_his['audio_url'] = audio_url
history.append(new_his)
self.db.update_conversation(username, conversation_id, history=json.dumps(history))
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
return Response().ok().__dict__
@@ -168,7 +168,7 @@ class ChatRoute(Route):
continue
yield result_text + '\n'
conversation = self.db.get_conversation_by_user_id(username, cid)
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
@@ -178,7 +178,7 @@ class ChatRoute(Route):
'type': 'bot',
'message': result_text
})
self.db.update_conversation(username, cid, history=json.dumps(history))
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
await asyncio.sleep(0.5)
except BaseException as e:
@@ -204,20 +204,20 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_conversation(username, conversation_id)
self.db.delete_webchat_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.new_conversation(username, conversation_id)
self.db.webchat_new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_conversations(username)
conversations = self.db.get_webchat_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
@@ -226,7 +226,7 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
self.curr_user_cid[username] = conversation_id

View File

@@ -1,12 +1,14 @@
import os
import json
import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -17,9 +19,9 @@ def try_cast(value: str, type_: str):
elif type_ == "float" and isinstance(value, int):
return float(value)
def validate_config(data, schema: dict, is_core: bool):
def validate_config(data, config: AstrBotConfig):
errors = []
def validate(data, metadata=schema, path=""):
def validate(data, metadata=CONFIG_METADATA_2, path=""):
for key, meta in metadata.items():
if key not in data:
continue
@@ -54,33 +56,35 @@ def validate_config(data, schema: dict, is_core: bool):
elif meta["type"] == "object" and not isinstance(value, dict):
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
if is_core:
for key, group in schema.items():
group_meta = group.get("metadata")
if not group_meta:
continue
logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
validate(data)
return errors
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
'''验证并保存配置'''
errors = None
try:
if is_core:
errors = validate_config(post_config, CONFIG_METADATA_2, is_core)
else:
errors = validate_config(post_config, config.schema, is_core)
except BaseException as e:
logger.warning(f"验证配置时出现异常: {e}")
errors = validate_config(post_config, config)
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
def save_extension_config(post_config: dict):
if 'namespace' not in post_config:
raise ValueError("Missing key: namespace")
if 'config' not in post_config:
raise ValueError("Missing key: config")
namespace = post_config['namespace']
config: list = post_config['config'][0]['body']
for item in config:
key = item['path']
value = item['value']
typ = item['val_type']
if typ == 'int':
if not value.isdigit():
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
value = int(value)
update_config(namespace, key, value)
class ConfigRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
@@ -88,17 +92,17 @@ class ConfigRoute(Route):
self.routes = {
'/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs),
'/config/plugin/update': ('POST', self.post_extension_configs),
}
self.register_routes()
async def get_configs(self):
# plugin_name 为空时返回 AstrBot 配置
# 否则返回指定 plugin_name 的插件配置
plugin_name = request.args.get("plugin_name", None)
if not plugin_name:
# namespace 为空时返回 AstrBot 配置
# 否则返回指定 namespace 的插件配置
namespace = "" if "namespace" not in request.args else request.args["namespace"]
if not namespace:
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
return Response().ok(await self._get_extension_config(namespace)).__dict__
async def post_astrbot_configs(self):
post_configs = await request.json
@@ -106,15 +110,14 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
logger.error(e)
traceback.print_exc()
return Response().error(str(e)).__dict__
async def post_plugin_configs(self):
async def post_extension_configs(self):
post_configs = await request.json
plugin_name = request.args.get("plugin_name", "unknown")
try:
await self._save_plugin_configs(post_configs, plugin_name)
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
await self._save_extension_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
return Response().error(str(e)).__dict__
@@ -138,48 +141,28 @@ class ConfigRoute(Route):
"config": config
}
async def _get_plugin_config(self, plugin_name: str):
ret = {
"metadata": None,
"config": None
}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
if not plugin_md.config:
break
ret['config'] = plugin_md.config # 这是自定义的 Dict 类AstrBotConfig
ret['metadata'] = {
plugin_name: {
"description": f"{plugin_name} 配置",
"type": "object",
"items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema
}
}
break
return ret
async def _get_extension_config(self, namespace: str):
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
return [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
async def _save_astrbot_configs(self, post_configs: dict):
try:
save_config(post_configs, self.config, is_core=True)
save_astrbot_config(post_configs, self.config)
self.core_lifecycle.restart()
except Exception as e:
raise e
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
md = None
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
md = plugin_md
if not md:
raise ValueError(f"插件 {plugin_name} 不存在")
if not md.config:
raise ValueError(f"插件 {plugin_name} 没有注册配置")
async def _save_extension_configs(self, post_configs: dict):
try:
save_config(post_configs, md.config)
save_extension_config(post_configs)
self.core_lifecycle.restart()
except Exception as e:
raise e

View File

@@ -25,29 +25,15 @@ class PluginRoute(Route):
self.register_routes()
async def get_online_plugins(self):
custom = request.args.get("custom_registry")
if custom:
urls = [custom]
else:
urls = [
"https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json",
"https://api.soulter.top/astrbot/plugins"
]
for url in urls:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
if response.status == 200:
result = await response.json()
return Response().ok(result).__dict__
else:
logger.error(f"请求 {url} 失败,状态码:{response.status}")
except Exception as e:
logger.error(f"请求 {url} 失败,错误:{e}")
return Response().error("获取插件列表失败").__dict__
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
result = await response.json()
return Response().ok(result).__dict__
except Exception as e:
logger.error(f"获取插件列表失败:{e}")
return Response().error(str(e)).__dict__
async def get_plugins(self):
_plugin_resp = []
@@ -70,7 +56,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -82,10 +67,9 @@ class PluginRoute(Route):
file = await request.files
file = file['file']
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}"
file_path = f"data/temp/{uuid.uuid4()}.zip"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
self.core_lifecycle.restart()
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -110,7 +94,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:

View File

@@ -6,10 +6,6 @@ class StaticFileRoute(Route):
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat']
for i in index_:
self.app.add_url_rule(i, view_func=self.index)
@self.app.errorhandler(404)
async def page_not_found(e):
return "404 Not found。如果你初次使用打开面板发现 404请参考文档: https://astrbot.app/deploy/dashboard-404.html"
async def index(self):
return await self.app.send_static_file('index.html')

View File

@@ -1,5 +1,4 @@
import traceback
import aiohttp
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -44,7 +43,7 @@ class UpdateRoute(Route):
}
).__dict__
except Exception as e:
logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)")
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
async def update_project(self):
@@ -64,13 +63,6 @@ class UpdateRoute(Route):
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
# pip 更新依赖
logger.info("更新依赖中...")
try:
pip_installer.install(requirements_path="requirements.txt")
except Exception as e:
logger.error(f"更新依赖失败: {e}")
if reboot:
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()

View File

@@ -1,6 +0,0 @@
# What's Changed
- Gewechat 微信支持图片、语音的收和发
- 支持 OpenAI TTS文字转语音
- 支持路径映射,解决 docker 部署时两端文件系统不一致导致的富媒体文件路径不存在问题
- Napcat 下语音消息可能接收异常

View File

@@ -1,4 +0,0 @@
# What's Changed
- 修复 astrbot_updator 属性缺失与stt_enabled 未初始化 #252
- 支持消息分段回复

View File

@@ -1,8 +0,0 @@
# What's Changed
- 修复: TTS 问题
- 新增: **支持记录非唤醒状态下群聊历史记录(beta)**
- 优化: 自动删除 deepseek-r1 模型自带的 think 标签
- 优化: 自动移除 ollama 不支持 tool 的模型的 tool 请求
- 优化: /t2i 即时生效
- 优化: gewechat 消息下发异常处理

View File

@@ -1,9 +0,0 @@
# What's Changed
- 修复: 配置 Validator 不起效的问题
- 修复: DeepSeek-R1 思考标签问题
- 修复: 分段回复间隔时间不生效
- 修复: 修复白名单为空时依然终止事件 #259
- 修复: 群聊增强某些参数的类型转换问题
- 新增: 插件支持注册配置,详见 [注册插件配置](https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta)
- 优化: 插件的禁用/启用逻辑以及函数工具的禁用/启用逻辑

View File

@@ -1,6 +0,0 @@
# What's Changed
- [gewechat] [修复每次启动astrbot都需要扫码的问题](https://github.com/Soulter/AstrBot/commit/fd5d7dd37a6d74f81a148bbebef8516aa0cb5540)
- [core] [Provider 重复时不直接报错闪退](https://github.com/Soulter/AstrBot/commit/b61f9be18db9a6b8b3c5b6b36553f66dd2b79375) https://github.com/Soulter/AstrBot/issues/265
- [core] [弱化更新报错](https://github.com/Soulter/AstrBot/commit/0ba0150fd8ff2062dbe83889163888ba3e33bd49) https://github.com/Soulter/AstrBot/issues/267
- 修复 webui 无法从本地上传插件的问题

View File

@@ -1,11 +0,0 @@
# What's Changed
- [beta] 支持群聊内基于概率的主动回复
- openai tts 更换模型 #300
- 增加模型响应后的插件钩子
- 修复 相同type的provider共享了记忆
- 优化 人格情景在发现格式不对时仍然加载而不是跳过 #282
- 修复 Gemini函数调用时parameters为空对象导致的错误 by @Camreishi
- 修复 弹出记录报错的问题 #272
- 优化 移除默认人格
- 优化 未启用模型提供商时的异常处理

View File

@@ -1,12 +0,0 @@
# What's Changed
- fix: 修复主动概率回复关闭后仍然回复的问题 #317
- fix: 尝试修复 gewechat 群聊收不到 at 的回复 #294
- perf: 移除了默认人格
- fix: 修复HTTP代理删除后不生效 #319
- fix: 调用Gemini API输出多余空行问题 #318
- feat: 添加硅基流动模版
- fix: 硅基流动 not a vlm 和 tool calling not supported 报错 #305 #291
- perf: 回复时艾特发送者之后添加空格或换行 #312
- fix: docker容器内时区不对导致 reminder 时间错误
- perf: siliconcloud 不支持 tool 的模型

View File

@@ -1,13 +0,0 @@
# What's Changed
1. 支持接入企业微信(测试)
2. 修复速率限制不可用的问题
3. gewechat 回调接口默认暴露在所有 IP
4. 适配 Azure OpenAI
5. 修复请求 gemini 出现 KeyError 'candidates' 的错误
6. 将 /reset /persona 挪入管理员指令 #308
7. 支持通过 /alter_cmd 设置所有指令是否只能管理员操作
8. /plugin 指令支持查看插件注册的指令和指令组
9. 插件注册指令支持传入指令的描述以方便 /plugin 查看。需要写在函数的第一行的 docstring 中。
10. 修复 schema 中 object hint 不显示 #290
11. feat: 优化插件市场的访问速度

View File

@@ -1,15 +0,0 @@
# What's Changed
> 由于重写了会话记录部分,更新此版本后,将会造成之前的对话记录清空(但没有被删除)。
> 关于更好的对话管理,如果有任何报错或者优化建议,请直接提交 issue~
1. 更好的对话管理,支持 /ls, /del, /new, /switch, /rename 指令来操作对话。
2. 人格情境跟随对话。每个对话支持独立设置人格情境,只需要 /persona 指令切换即可。
3. 支持使用 LLM 辅助分段回复 #338
4. 优化 aiocqhttp 适配器对用户非法输入的处理
5. 优化插件页面
6. 修复权限过滤算子导致的问题 #350
7. 修复级联指令组时出现载入错误的问题 #366
8. 修复代码执行器的一个typo by @eltociear
9. 修复指令组情况下可能造成多指令出触发的问题
10. 添加屏蔽无权限指令回复的功能 #361

View File

@@ -4,8 +4,6 @@ services:
astrbot:
image: soulter/astrbot:latest
container_name: astrbot
ports:
- "6180-6200:6180-6200"
- "11451:11451"
network_mode: "host"
volumes:
- ./data:/AstrBot/data
- ./data:/AstrBot/data

View File

@@ -4,7 +4,7 @@
</h3>
<v-card-text>
<div v-for="(index, key) in iterable" :key="key" style="margin-bottom: 0.5px;" v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template">
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint" style="margin-bottom: 16px"
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object'" style="margin-bottom: 16px"
:text="metadata[metadataKey].items[key]?.hint" :title="'💡 关于' + metadata[metadataKey].items[key]?.description"
type="info" variant="tonal">
</v-alert>
@@ -52,7 +52,7 @@
</div>
<div
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && !metadata[metadataKey].items[key]?.invisible">
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object' && !metadata[metadataKey].items[key]?.invisible">
<v-btn icon size="x-small" style="margin-bottom: 22px;">
<v-icon size="x-small">mdi-help</v-icon>
<v-tooltip activator="parent" location="start">{{ metadata[metadataKey].items[key]?.hint
@@ -63,7 +63,7 @@
</div>
<div v-else>
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint" style="margin-bottom: 16px"
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && metadata[metadataKey]?.type !== 'object'" style="margin-bottom: 16px"
:text="metadata[metadataKey]?.hint" :title="'💡 关于' + metadata[metadataKey]?.description"
type="info" variant="tonal">
</v-alert>
@@ -106,7 +106,7 @@
</div>
<div
v-if="!metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && !metadata[metadataKey]?.invisible">
v-if="!metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && metadata[metadataKey]?.type !== 'object' && !metadata[metadataKey]?.invisible">
<v-btn icon size="x-small" style="margin-bottom: 22px;">
<v-icon size="x-small">mdi-help</v-icon>
<v-tooltip activator="parent" location="start">{{ metadata[metadataKey]?.hint

View File

@@ -0,0 +1,41 @@
<script setup>
import UiParentCard from '@/components/shared/UiParentCard.vue';
const props = defineProps({
config: Array
});
</script>
<template>
<a v-show="config.length === 0">该插件没有配置</a>
<UiParentCard v-for="group in config" :key="group.name" :title="group.name" style="margin-bottom: 16px;">
<template v-for="item in group.body">
<template v-if="item.config_type === 'item'">
<template v-if="item.val_type === 'bool'">
<v-switch v-model="item.value" :label="item.name" :hint="item.description" color="primary" inset></v-switch>
</template>
<template v-else-if="item.val_type === 'str'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'int'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'list'">
<span>{{ item.name }}</span>
<v-combobox v-model="item.value" chips clearable label="请添加" multiple prepend-icon="mdi-tag-multiple-outline">
<template v-slot:selection="{ attrs, item, select, selected }">
<v-chip v-bind="attrs" :model-value="selected" closable @click="select" @click:close="remove(item)">
<strong>{{ item }}</strong>
</v-chip>
</template>
</v-combobox>
</template>
</template>
<template v-else-if="item.config_type === 'divider'">
<v-divider style="margin-top: 8px; margin-bottom: 8px;"></v-divider>
</template>
</template>
</UiParentCard>
</template>

View File

@@ -1,8 +1,7 @@
<script setup lang="ts">
const props = defineProps({
title: String,
link: String,
logo: String
link: String
});
const open = (link: string | undefined) => {
@@ -14,7 +13,6 @@ const open = (link: string | undefined) => {
<v-card variant="outlined" elevation="0" class="withbg">
<v-card-item style="padding: 10px 14px">
<div class="d-sm-flex align-center justify-space-between">
<img v-if="logo" :src="logo" alt="logo" style="width: 40px; height: 40px; margin-right: 8px;">
<v-card-title style="font-size: 17px;">{{ props.title }}</v-card-title>
<v-spacer></v-spacer>
<v-btn variant="plain" @click="open(props.link)">仓库</v-btn>

View File

@@ -1,7 +1,7 @@
<script setup>
import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import ConfigDetailCard from '@/components/shared/ConfigDetailCard.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
@@ -9,77 +9,54 @@ import axios from 'axios';
<template>
<v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败可以自行前往仓库下载压缩包然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README" title="💡提示"
type="info" variant="tonal">
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败可以自行前往仓库下载压缩包然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
</v-alert>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
<h3>🧩 已安装的插件</h3>
</div>
</v-col>
<v-col cols="12" md="6" lg="3" v-for="extension in extension_data.data">
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" :logo="extension?.logo"
style="margin-bottom: 4px;">
<v-col cols="12" md="6" lg="4" v-for="extension in extension_data.data">
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" style="margin-bottom: 4px;">
<p style="min-height: 130px; max-height: 130px; overflow: none;">{{ extension.desc }}</p>
<div class="d-flex align-center gap-2">
<v-icon>mdi-account</v-icon>
<span>{{ extension.author }}</span>
<v-spacer></v-spacer>
<div v-if="!extension.reserved">
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="openExtensionConfig(extension.name)">配置</v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="updateExtension(extension.name)">更新</v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="uninstallExtension(extension.name)">卸载</v-btn>
<v-btn variant="plain" @click="openExtensionConfig(extension.name)">配置</v-btn>
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
</div>
<!-- <span v-else>保留插件</span> -->
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-if="extension.activated"
@click="pluginOff(extension)"></v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-else
@click="pluginOn(extension)">启用</v-btn>
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
<v-btn variant="plain" v-else @click="pluginOn(extension)"></v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
<div style="display: flex; align-items: center;">
<h3>🧩 插件市场</h3>
<small style="margin-left: 16px;">如无法显示请打开 <a
href="https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json">链接</a> 复制想安装插件对应的 `repo`
链接然后点击右下角 + 号安装或打开链接下载压缩包安装</small>
</div>
<h3>🧩 插件市场</h3>
</div>
</v-col>
<v-col cols="12" md="12" v-if="announcement">
<v-banner color="success" lines="one" :text="announcement" :stacked="false" >
</v-banner>
</v-col>
<v-col cols="12" md="6" lg="3" v-for="plugin in pluginMarketData">
<v-col cols="12" md="6" lg="4" v-for="plugin in pluginMarketData">
<ExtensionCard :key="plugin.name" :title="plugin.name" :link="plugin.repo" style="margin-bottom: 4px;">
<p style="min-height: 130px; max-height: 130px; overflow: hidden;">{{ plugin.desc }}</p>
<div class="d-flex align-center gap-2">
<v-icon>mdi-account</v-icon>
<span>{{ plugin.author }}</span>
<v-spacer></v-spacer>
<v-btn v-if="!plugin.installed" class="text-none mr-2" size="small" text="Read" variant="flat" border
<v-btn v-if="!plugin.installed" variant="plain"
@click="extension_url = plugin.repo; newExtension()">安装</v-btn>
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border disabled>已安装</v-btn>
<v-btn v-else variant="plain" disabled>已安装</v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col style="margin-bottom: 16px;" cols="12" md="12">
<small><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
</v-col>
</v-row>
<v-dialog v-model="configDialog" width="1000">
<v-dialog v-model="configDialog" width="750">
<template v-slot:activator="{ props }">
</template>
<v-card>
@@ -88,9 +65,7 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-container>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata"
:iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<p v-else>这个插件没有配置</p>
<ConfigDetailCard :config="extension_config"></ConfigDetailCard>
</v-container>
</v-card-text>
<v-card-actions>
@@ -197,9 +172,9 @@ export default {
name: 'ExtensionPage',
components: {
ExtensionCard,
ConfigDetailCard,
WaitingForRestart,
ConsoleDisplayer,
AstrBotConfig
ConsoleDisplayer
},
data() {
return {
@@ -214,10 +189,7 @@ export default {
snack_success: "success",
loading_: false,
configDialog: false,
extension_config: {
"metadata": {},
"config": {}
},
extension_config: {},
upload_file: null,
pluginMarketData: {},
loadingDialog: {
@@ -225,19 +197,12 @@ export default {
title: "加载中...",
statusCode: 0, // 0: loading, 1: success, 2: error,
result: ""
},
announcement: ""
}
}
},
mounted() {
this.getExtensions();
this.fetchPluginCollection();
axios.get('https://api.soulter.top/astrbot-announcement-plugin-market').then((res) => {
let data = res.data.data;
this.announcement = data.text;
});
},
methods: {
toast(message, success) {
@@ -399,7 +364,7 @@ export default {
openExtensionConfig(extension_name) {
this.curr_namespace = extension_name;
this.configDialog = true;
axios.get('/api/config/get?plugin_name=' + extension_name).then((res) => {
axios.get('/api/config/get?namespace=' + extension_name).then((res) => {
this.extension_config = res.data.data;
console.log(this.extension_config);
}).catch((err) => {
@@ -407,7 +372,10 @@ export default {
});
},
updateConfig() {
axios.post('/api/config/plugin/update?plugin_name=' + this.curr_namespace, this.extension_config.config).then((res) => {
axios.post('/api/config/plugin/update', {
"config": this.extension_config,
"namespace": this.curr_namespace
}).then((res) => {
if (res.data.status === "ok") {
this.toast(res.data.message, "success");
this.$refs.wfr.check();
@@ -443,23 +411,11 @@ export default {
}
for (let i = 0; i < this.pluginMarketData.length; i++) {
for (let j = 0; j < this.extension_data.data.length; j++) {
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo || this.pluginMarketData[i].name === this.extension_data.data[j].name) {
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo) {
this.pluginMarketData[i].installed = true;
}
}
}
// 将已安装的插件移动到最后面
let installed = [];
let notInstalled = [];
for (let i = 0; i < this.pluginMarketData.length; i++) {
if (this.pluginMarketData[i].installed) {
installed.push(this.pluginMarketData[i]);
} else {
notInstalled.push(this.pluginMarketData[i]);
}
}
this.pluginMarketData = notInstalled.concat(installed);
}
},
}

View File

@@ -1,128 +0,0 @@
import datetime
import uuid
import random
import astrbot.api.star as star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.platform import MessageType
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Plain, Image
from astrbot import logger
from collections import defaultdict
'''
聊天记忆增强
'''
class LongTermMemory:
def __init__(self, config: dict, context: star.Context):
self.config = config
self.context = context
self.session_chats = defaultdict(list)
"""记录群成员的群聊记录"""
try:
self.max_cnt = int(self.config["group_message_max_cnt"])
except BaseException as e:
logger.error(e)
self.max_cnt = 300
self.image_caption = self.config["image_caption"]
self.image_caption_prompt = self.config["image_caption_prompt"]
self.active_reply = self.config["active_reply"]
self.enable_active_reply = self.active_reply.get("enable", False)
self.ar_method = self.active_reply["method"]
self.ar_possibility = self.active_reply["possibility_reply"]
self.ar_prompt = self.active_reply.get("prompt", "")
self.put_history_to_prompt = self.config["put_history_to_prompt"]
async def remove_session(self, event: AstrMessageEvent) -> int:
cnt = 0
if event.unified_msg_origin in self.session_chats:
cnt = len(self.session_chats[event.unified_msg_origin])
del self.session_chats[event.unified_msg_origin]
return cnt
async def get_image_caption(self, image_url: str) -> str:
provider = self.context.get_using_provider()
response = await provider.text_chat(
prompt=self.image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
return response.completion_text
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
if not self.enable_active_reply:
return False
if event.get_message_type() != MessageType.GROUP_MESSAGE:
return False
if event.is_at_or_wake_command:
# if the message is a command, let it pass
return False
match self.ar_method:
case "possibility_reply":
trig = random.random() < self.ar_possibility
return trig
return False
async def handle_message(self, event: AstrMessageEvent):
'''仅支持群聊'''
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
for comp in event.get_messages():
if isinstance(comp, Plain):
final_message += f" {comp.text}"
elif isinstance(comp, Image):
# image_urls.append(comp.url if comp.url else comp.file)
if self.image_caption:
try:
caption = await self.get_image_caption(
comp.url if comp.url else comp.file
)
final_message += f" [Image: {caption}]"
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
else:
final_message += " [Image]"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
'''当触发 LLM 请求前,调用此方法修改 req'''
if event.unified_msg_origin not in self.session_chats:
return
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
if self.put_history_to_prompt:
prompt = req.prompt
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
req.contexts = [] # 清空上下文当使用了群聊增强所有聊天记录都在一个prompt中。
else:
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
req.system_prompt += chats_str
if self.image_caption:
req.system_prompt += (
"The images sent by the members are displayed in text form above."
)
async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats:
return
if event.get_result() and event.get_result().is_llm_result():
final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)

View File

@@ -1,25 +1,17 @@
import aiohttp
import datetime
import builtins
import json
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import sp
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
from astrbot.api.provider import Personality, ProviderRequest
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.config.default import VERSION
from .long_term_memory import LongTermMemory
from astrbot.core import logger
from typing import Union
@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0")
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
@@ -28,12 +20,7 @@ class Main(star.Star):
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.ltm = None
if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']:
try:
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
except BaseException as e:
logger.error(f"聊天增强 err: {e}")
self.kdb_enabled = False
async def _query_astrbot_notice(self):
try:
@@ -45,7 +32,6 @@ class Main(star.Star):
@filter.command("help")
async def help(self, event: AstrMessageEvent):
'''查看帮助'''
notice = ""
try:
notice = await self._query_astrbot_notice()
@@ -55,37 +41,31 @@ class Main(star.Star):
dashboard_version = await get_dashboard_version()
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
AstrBot 指令:
已注册的 AstrBot 内置指令:
[System]
/plugin: 查看插件、插件帮助
/t2i: 开文本转图片
/sid: 获取会话 ID
/op <admin_id>: 授权管理员(op)
/deop <admin_id>: 取消管理员(op)
/wl <sid>: 添加白名单(op)
/dwl <sid>: 删除白名单(op)
/dashboard_update: 更新管理面板(op)
/alter_cmd: 设置指令权限(op)
/plugin: 查看注册的插件、插件帮助
/t2i: 开启/关闭文本转图片模式
/sid: 获取当前会话 ID
/op <admin_id>: 授权管理员
/deop <admin_id>: 取消管理员
/wl <sid>: 添加会话白名单
/dwl <sid>: 删除会话白名单
/dashboard_update: 更新管理面板
[大模型]
/provider: 大模型提供商
/model: 模型列表
/ls: 对话列表
/new: 创建新对
/switch: 切换对话
/rename: 重命名对话
/del: 删除当前会话对话(op)
/reset: 重置 LLM 会话(op)
/history: 当前对话的对话记录
/persona: 人格情景(op)
/tool ls: 函数工具
/key: API Key(op)
/provider: 查看、切换大模型提供商
/model: 查看、切换提供商模型列表
/key: 查看、切换 API Key
/reset: 重置 LLM 会
/history: 获取会话历史记录
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
[其他]
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除会话的变量。
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除当前会话的变量。
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@@ -102,7 +82,7 @@ AstrBot 指令:
active = " (启用)" if tool.active else "(停用)"
msg += f"- {tool.name}: {tool.description} {active}\n"
msg += "\n使用 /tool on/off <工具名> 激活或者停用函数工具。/tool off_all 停用所有函数工具。"
msg += "\n使用 /tool on/off <工具名> 激活或者停用工具。"
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@tool.command("on")
@@ -118,13 +98,6 @@ AstrBot 指令:
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。"))
else:
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
@tool.command("off_all")
async def tool_all_off(self, event: AstrMessageEvent):
tm = self.context.get_llm_tool_manager()
for tool in tm.func_list:
self.context.deactivate_llm_tool(tool.name)
event.set_result(MessageEventResult().message(f"停用所有工具成功。"))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
@@ -135,7 +108,7 @@ AstrBot 指令:
if plugin_list_info.strip() == "":
plugin_list_info = "没有加载任何插件。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助和加载的指令\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
else:
if oper1 == "off":
@@ -158,34 +131,10 @@ AstrBot 指令:
plugin = self.context.get_registered_star(oper1)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
return
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "帮助信息: 未提供"
help_msg += f"\n\n作者: {plugin.author}\n版本: {plugin.version}"
command_handlers = []
command_names = []
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
if handler.handler_module_path != plugin.module_path:
continue
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
command_handlers.append(handler)
command_names.append(filter_.command_name)
break
elif isinstance(filter_, CommandGroupFilter):
command_handlers.append(handler)
command_names.append(filter_.group_name)
if len(command_handlers) > 0:
help_msg += "\n\n指令列表:\n"
for i in range(len(command_handlers)):
help_msg += f"{command_names[i]}: {command_handlers[i].desc}\n"
help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。"
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README。"
event.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
event.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
@@ -245,10 +194,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
async def provider(self, event: AstrMessageEvent, idx: int = None):
'''查看或者切换 LLM Provider'''
if not self.context.get_using_provider():
event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if idx is None:
ret = "## 当前载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
@@ -271,39 +216,13 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
await self.context.conversation_manager.update_conversation(
message.unified_msg_origin, cid, []
)
ret = "清除会话 LLM 聊天历史成功。"
if self.ltm:
cnt = await self.ltm.remove_session(event=message)
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
message.set_result(MessageEventResult().message(ret))
await self.context.get_using_provider().forget(message.session_id)
message.set_result(MessageEventResult().message("重置成功"))
@filter.command("model")
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if idx_or_name is None:
models = []
try:
@@ -344,29 +263,14 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
'''查看对话记录'''
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
size_per_page = 6
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not session_curr_cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
contexts, total_pages = await self.context.conversation_manager.get_human_readable_context(
message.unified_msg_origin, session_curr_cid, page, size_per_page
)
size_per_page = 3
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
history = ""
for context in contexts:
if len(context) > 150:
context = context[:150] + "..."
history += f"{context}\n"
ret = f"""当前对话历史记录:
ret = f"""历史记录:
{history}
{page} 页 | 共 {total_pages}
@@ -375,95 +279,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1):
'''查看对话列表'''
size_per_page = 6
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
total_pages = len(conversations) // size_per_page
if len(conversations) % size_per_page != 0:
total_pages += 1
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
ret = "对话列表:\n---\n"
global_index = (page - 1) * size_per_page + 1
_titles = {}
for conv in conversations:
persona_id = conv.persona_id
if not persona_id and not persona_id == "[%None]":
persona_id = self.context.provider_manager.selected_default_persona['name']
title = conv.title if conv.title else "新对话"
_titles[conv.cid] = title
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
global_index += 1
ret += "---\n"
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if curr_cid:
ret += f"\n当前对话: {_titles[curr_cid]}({curr_cid[:4]})"
else:
ret += "\n当前对话: 无"
unique_session = self.context.get_config()['platform_settings']['unique_session']
if unique_session:
ret += "\n会话隔离粒度: 个人"
else:
ret += "\n会话隔离粒度: 群聊"
ret += f"\n{page} 页 | 共 {total_pages}"
ret += "\n*输入 /ls 2 跳转到第 2 页"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent):
'''创建新对话'''
cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin)
message.set_result(MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"))
@filter.command("switch")
async def switch_conv(self, message: AstrMessageEvent, index: int):
'''通过 /ls 前面的序号切换对话'''
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
if index > len(conversations) or index < 1:
message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看"))
else:
conversation = conversations[index-1]
title = conversation.title if conversation.title else "新对话"
await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid)
message.set_result(MessageEventResult().message(f"切换到对话: {title}({conversation.cid[:4]})。"))
@filter.command("rename")
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
'''重命名对话'''
await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name)
message.set_result(MessageEventResult().message("重命名对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent):
'''删除当前对话'''
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not session_curr_cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
message.set_result(MessageEventResult().message("删除当前对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int=None):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
if index is None:
keys_data = self.context.get_using_provider().get_keys()
@@ -490,35 +308,23 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
MessageEventResult().message("切换 Key 未知错误: "+str(e)))
message.set_result(MessageEventResult().message("切换 Key 成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
l = message.message_str.split(" ")
curr_persona_name = ""
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
curr_cid_title = ""
if cid:
conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid)
if not conversation.persona_id and not conversation.persona_id == "[%None]":
curr_persona_name = self.context.provider_manager.selected_default_persona['name']
else:
curr_persona_name = conversation.persona_id
curr_cid_title = conversation.title if conversation.title else "新对话"
curr_cid_title += f"({cid[:4]})"
if self.context.get_using_provider().curr_personality:
curr_persona_name = self.context.get_using_provider().curr_personality['name']
if len(l) == 1:
message.set_result(
MessageEventResult().message(f"""[Persona]
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
- 人格情景列表: `/persona list`
- 设置人格情景: `/persona 人格`
- 人格情景详细信息: `/persona view 人格`
- 取消人格: `/persona unset`
- 人格情景详细信息: `/persona view 人格`
默认人格情景: {self.context.provider_manager.selected_default_persona['name']}
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
当前人格情景: {curr_persona_name}
配置人格情景请前往管理面板-配置页
""").use_t2i(False))
@@ -542,22 +348,16 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
else:
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
elif l[1] == "unset":
if not cid:
message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。"))
return
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]")
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
ps = "".join(l[1:]).strip()
if persona := next(builtins.filter(
lambda persona: persona['name'] == ps,
self.context.provider_manager.personas
), None):
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps)
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
self.context.get_using_provider().curr_personality = persona
message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
else:
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
@@ -566,6 +366,31 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await download_dashboard()
yield event.plain_result("管理面板更新完成。")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
# if provider.curr_personality['prompt']:
# req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
@@ -603,189 +428,32 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await platform.logout()
yield event.plain_result("已登出 gewechat")
return
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''群聊记忆增强'''
if self.ltm:
need_active = await self.ltm.need_active_reply(event)
group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable']
if group_icl_enable:
'''记录对话'''
try:
await self.ltm.handle_message(event)
except BaseException as e:
logger.error(e)
if need_active:
'''主动回复'''
provider = self.context.get_using_provider()
if not provider:
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
return
try:
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin)
if not session_curr_cid:
logger.error("当前未处于对话状态,无法主动回复,请使用 /switch 切换或者 /new 创建。")
return
conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid
)
history = json.loads(conv.history)
prompt = self.ltm.ar_prompt
if not prompt:
prompt = event.message_str
yield event.request_llm(
prompt=prompt,
func_tool_manager=self.context.get_llm_tool_manager(),
session_id=event.session_id,
contexts=history if history else []
)
except BaseException as e:
logger.error(f"主动回复失败: {e}")
@filter.command_group("kdb")
def kdb(self):
pass
@kdb.command("on")
async def on_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = True
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if not curr_kdb_name:
yield event.plain_result("未载入任何知识库")
else:
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
@kdb.command("off")
async def off_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = False
yield event.plain_result("知识库已关闭")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if req.conversation:
persona_id = req.conversation.persona_id
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
persona_id = self.context.provider_manager.selected_default_persona['name']
persona = next(builtins.filter(
lambda persona: persona['name'] == persona_id,
self.context.provider_manager.personas
), None)
if persona:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
if self.ltm:
try:
await self.ltm.on_req_llm(event, req)
except BaseException as e:
logger.error(f"ltm: {e}")
@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
'''在 LLM 请求后记录对话'''
if self.ltm:
try:
await self.ltm.after_req_llm(event)
except BaseException as e:
logger.error(f"ltm: {e}")
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd")
async def alter_cmd(self, event: AstrMessageEvent):
# token = event.message_str.split(" ")
token = self.parse_commands(event.message_str)
if token.len < 2:
yield event.plain_result("可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令")
return
cmd_name = token.get(1)
cmd_type = token.get(2)
if cmd_type not in ["admin", "member"]:
yield event.plain_result("指令类型错误,可选类型有 admin, member")
return
# 查找指令
found_command = None
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.command_name == cmd_name:
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if cmd_name == filter_.group_name:
found_command = handler
break
if not found_command:
yield event.plain_result("未找到该指令")
return
found_plugin = star_map[found_command.handler_module_path]
alter_cmd_cfg = sp.get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
cfg = plugin_.get(found_command.handler_name, {})
cfg["permission"] = cmd_type
plugin_[found_command.handler_name] = cfg
alter_cmd_cfg[found_plugin.name] = plugin_
sp.put("alter_cmd", alter_cmd_cfg)
# 注入权限过滤器
found_permission_filter = False
for filter_ in found_command.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
filter_.permission_type = filter.PermissionType.ADMIN
else:
filter_.permission_type = filter.PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
found_command.event_filters.insert(0, PermissionTypeFilter(filter.PermissionType.ADMIN if cmd_type == "admin" else filter.PermissionType.MEMBER))
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
# @filter.command_group("kdb")
# def kdb(self):
# pass
# @kdb.command("on")
# async def on_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = True
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if not curr_kdb_name:
# yield event.plain_result("未载入任何知识库")
# else:
# yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
# @kdb.command("off")
# async def off_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = False
# yield event.plain_result("知识库已关闭")
# @filter.on_llm_request()
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if self.kdb_enabled and curr_kdb_name:
# mgr = self.context.knowledge_db_manager
# results = await mgr.retrive_records(curr_kdb_name, req.prompt)
# if results:
# req.system_prompt += "\nHere are documents that related to user's query: \n"
# for result in results:
# req.system_prompt += f"- {result}\n"
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if self.kdb_enabled and curr_kdb_name:
mgr = self.context.knowledge_db_manager
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
if results:
req.system_prompt += "\nHere are documents that related to user's query: \n"
for result in results:
req.system_prompt += f"- {result}\n"

View File

@@ -358,7 +358,7 @@ class Main(star.Star):
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
break
@@ -393,4 +393,4 @@ class Main(star.Star):
await container.kill()
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
finally:
await container.delete()
await container.delete()

View File

@@ -8,8 +8,7 @@ from typing import List
class Google(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.proxy = os.environ.get("https_proxy")
print(f"Google Search using proxy: {self.proxy}")
self.proxy = os.environ.get("HTTPS_PROXY")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = []

View File

@@ -22,8 +22,6 @@ class Main(star.Star):
self.sogo_search = Sogo()
self.google = Google()
self.websearch_link = self.context.get_config()['provider_settings'].get('web_search_link', False)
async def initialize(self):
websearch = self.context.get_config()['provider_settings']['web_search']
if websearch:
@@ -111,17 +109,8 @@ class Main(star.Star):
except BaseException:
site_result = ""
site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result
header = f"{idx}. {i.title} "
if self.websearch_link and i.url:
header += i.url
ret += f"{header}\n{i.snippet}\n{site_result}\n\n"
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
if self.websearch_link:
ret += "针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
return ret

View File

@@ -1,4 +1,5 @@
pydantic~=2.10.3
vchat
aiohttp
openai
qq-botpy

33
stt.py Normal file
View File

@@ -0,0 +1,33 @@
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

View File

@@ -5,7 +5,6 @@ import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import CONFIG_METADATA_2
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType