Compare commits
100 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a59e894e1 | ||
|
|
7a8d65d37d | ||
|
|
fb10faa2dc | ||
|
|
23129a9ba2 | ||
|
|
e4df0e83ed | ||
|
|
7f791e730b | ||
|
|
4a22664b8e | ||
|
|
f7e296b349 | ||
|
|
712d4acaaa | ||
|
|
74a5c01f21 | ||
|
|
3ba8724d77 | ||
|
|
6313a7d8a9 | ||
|
|
432a3f520c | ||
|
|
191b3e42d4 | ||
|
|
a27f05fcb4 | ||
|
|
2f33e0b873 | ||
|
|
f0359467f1 | ||
|
|
d1db8cf2c8 | ||
|
|
b1985ed2ce | ||
|
|
140ddc70e6 | ||
|
|
d7fd616470 | ||
|
|
3ccbef141e | ||
|
|
e92fbb0443 | ||
|
|
bd270aed68 | ||
|
|
28d7864393 | ||
|
|
b5d8173ee3 | ||
|
|
17d62a9af7 | ||
|
|
d89fb863ed | ||
|
|
a21ad77820 | ||
|
|
f86c8e8cab | ||
|
|
cb12cbdd3d | ||
|
|
6661fa996c | ||
|
|
c19bca798b | ||
|
|
8f98b411db | ||
|
|
a8aa03847e | ||
|
|
1bfd747cc6 | ||
|
|
ae06d945a7 | ||
|
|
9f41d5f34d | ||
|
|
ef61c52908 | ||
|
|
d8842ef274 | ||
|
|
c88fdaf353 | ||
|
|
af295da871 | ||
|
|
083235a2fe | ||
|
|
2a3a5f7eb2 | ||
|
|
77c48f280f | ||
|
|
0ee1eb2f9f | ||
|
|
c2b20365bb | ||
|
|
cfdc7e4452 | ||
|
|
2363f61aa9 | ||
|
|
557ac6f9fa | ||
|
|
a49b871cf9 | ||
|
|
a0d6b3efba | ||
|
|
6cabf07bc0 | ||
|
|
a15444ee8c | ||
|
|
ceb5f5669e | ||
|
|
25b75e05e4 | ||
|
|
4d214bb5c1 | ||
|
|
7cbaed8c6c | ||
|
|
2915fdf665 | ||
|
|
a66c385b08 | ||
|
|
4dace7c5d8 | ||
|
|
8ebf087dbf | ||
|
|
2fa8bda5bb | ||
|
|
7cfbc4ab8f | ||
|
|
0f692b1608 | ||
|
|
2cc1eb1abc | ||
|
|
90dbcbb4e2 | ||
|
|
66503d58be | ||
|
|
8e10f0ce2b | ||
|
|
c44f085b47 | ||
|
|
a35f36eeaf | ||
|
|
14564c392a | ||
|
|
28a87351f1 | ||
|
|
dcd7dcbbdf | ||
|
|
1538759ba7 | ||
|
|
ec5d71d0e1 | ||
|
|
d121d08d05 | ||
|
|
be08f4a558 | ||
|
|
4df8606ab6 | ||
|
|
71442d26ec | ||
|
|
4f5528869c | ||
|
|
f16feff17b | ||
|
|
d8aae538cd | ||
|
|
31670e75e5 | ||
|
|
ed6011a2be | ||
|
|
cdded38ade | ||
|
|
f536f24833 | ||
|
|
646b18d910 | ||
|
|
e24225c828 | ||
|
|
50a296de20 | ||
|
|
c79e38e044 | ||
|
|
dae745d925 | ||
|
|
791db65526 | ||
|
|
02e2e617f5 | ||
|
|
bfc8024119 | ||
|
|
f26cf6ed6f | ||
|
|
f2be55bd8e | ||
|
|
d241dd17ca | ||
|
|
cecafdfe6c | ||
|
|
6fecfd1a0e |
@@ -53,7 +53,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
@@ -125,7 +125,6 @@ uvx astrbot init
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| 微信个人号 | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
@@ -246,11 +245,5 @@ _✨ WebUI ✨_
|
||||

|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||
3. Please ensure compliance with local laws and regulations when using this project.
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
## ✨ 主な機能
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
## 免責事項
|
||||
|
||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
||||
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
|
||||
<!-- ## ✨ ATRI [ベータテスト]
|
||||
|
||||
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "3.5.8"
|
||||
__version__ = "3.5.23"
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.21"
|
||||
VERSION = "3.5.23"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -157,15 +157,6 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"微信个人号(Gewechat)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
"微信个人号(WeChatPadPro)": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
@@ -318,8 +309,7 @@ CONFIG_METADATA_2 = {
|
||||
"id": {
|
||||
"description": "机器人名称",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
|
||||
"hint": "机器人名称",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -370,7 +360,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"discord_token": {
|
||||
"description": "Discord Bot Token",
|
||||
@@ -486,13 +475,11 @@ CONFIG_METADATA_2 = {
|
||||
"regex": {
|
||||
"description": "正则表达式",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||
},
|
||||
"content_cleanup_rule": {
|
||||
"description": "过滤分段后的内容",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
|
||||
},
|
||||
},
|
||||
@@ -515,7 +502,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
@@ -545,7 +531,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
},
|
||||
},
|
||||
@@ -605,6 +590,7 @@ CONFIG_METADATA_2 = {
|
||||
"config_template": {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
"provider": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -617,6 +603,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -630,6 +617,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -642,6 +630,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "claude",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -655,6 +644,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Ollama": {
|
||||
"id": "ollama_default",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -666,6 +656,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -677,6 +668,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -689,6 +681,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -714,6 +707,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"provider": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -726,6 +720,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
"provider": "302ai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -738,6 +733,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"硅基流动": {
|
||||
"id": "siliconflow",
|
||||
"provider": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -750,6 +746,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"id": "ppio",
|
||||
"provider": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -762,6 +759,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -774,6 +772,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -786,6 +785,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
"provider": "dify",
|
||||
"type": "dify",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -799,6 +799,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -815,6 +816,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"provider": "fastgpt",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -824,6 +826,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Whisper(API)": {
|
||||
"id": "whisper",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_api",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
@@ -833,15 +836,17 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"id": "whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(本地加载)": {
|
||||
"sensevoice_hint": "(不用修改我)",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "sensevoice",
|
||||
@@ -851,6 +856,7 @@ CONFIG_METADATA_2 = {
|
||||
"OpenAI TTS(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"provider": "openai",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
@@ -862,6 +868,7 @@ CONFIG_METADATA_2 = {
|
||||
"Edge TTS": {
|
||||
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"provider": "microsoft",
|
||||
"type": "edge_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
@@ -871,6 +878,7 @@ CONFIG_METADATA_2 = {
|
||||
"GSV TTS(本地加载)": {
|
||||
"id": "gsv_tts",
|
||||
"enable": False,
|
||||
"provider": "gpt_sovits",
|
||||
"type": "gsv_tts_selfhost",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:9880",
|
||||
@@ -902,6 +910,7 @@ CONFIG_METADATA_2 = {
|
||||
"GSVI TTS(API)": {
|
||||
"id": "gsvi_tts",
|
||||
"type": "gsvi_tts_api",
|
||||
"provider": "gpt_sovits_inference",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:5000",
|
||||
"character": "",
|
||||
@@ -911,6 +920,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"FishAudio TTS(API)": {
|
||||
"id": "fishaudio_tts",
|
||||
"provider": "fishaudio",
|
||||
"type": "fishaudio_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
@@ -921,6 +931,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"id": "dashscope_tts",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
@@ -932,6 +943,7 @@ CONFIG_METADATA_2 = {
|
||||
"Azure TTS": {
|
||||
"id": "azure_tts",
|
||||
"type": "azure_tts",
|
||||
"provider": "azure",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": True,
|
||||
"azure_tts_voice": "zh-CN-YunxiaNeural",
|
||||
@@ -945,6 +957,7 @@ CONFIG_METADATA_2 = {
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
"type": "minimax_tts_api",
|
||||
"provider": "minimax",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
@@ -966,6 +979,7 @@ CONFIG_METADATA_2 = {
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
"type": "volcengine_tts",
|
||||
"provider": "volcengine",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
@@ -979,6 +993,7 @@ CONFIG_METADATA_2 = {
|
||||
"Gemini TTS": {
|
||||
"id": "gemini_tts",
|
||||
"type": "gemini_tts",
|
||||
"provider": "google",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"gemini_tts_api_key": "",
|
||||
@@ -991,17 +1006,19 @@ CONFIG_METADATA_2 = {
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
"provider": "openai",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1536,
|
||||
"embedding_dimensions": 1024,
|
||||
"timeout": 20,
|
||||
},
|
||||
"Gemini Embedding": {
|
||||
"id": "gemini_embedding",
|
||||
"type": "gemini_embedding",
|
||||
"provider": "google",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
@@ -1012,17 +1029,19 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
},
|
||||
"gpt_weights_path": {
|
||||
"description": "GPT模型文件路径",
|
||||
"type": "string",
|
||||
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"sovits_weights_path": {
|
||||
"description": "SoVITS模型文件路径",
|
||||
"type": "string",
|
||||
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_default_parms": {
|
||||
"description": "GPT_SoVITS默认参数",
|
||||
@@ -1033,13 +1052,11 @@ CONFIG_METADATA_2 = {
|
||||
"description": "参考音频文件路径",
|
||||
"type": "string",
|
||||
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_prompt_text": {
|
||||
"description": "参考音频文本",
|
||||
"type": "string",
|
||||
"hint": "必填!请填写参考音频讲述的文本",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_prompt_lang": {
|
||||
"description": "参考音频文本语言",
|
||||
@@ -1266,19 +1283,16 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_native_coderunner": {
|
||||
"description": "启用原生代码执行器",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_url_context": {
|
||||
"description": "启用URL上下文功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
@@ -1462,7 +1476,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "部署SenseVoice",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"is_emotion": {
|
||||
"description": "情绪识别",
|
||||
@@ -1477,18 +1490,10 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {
|
||||
"description": "工作流固定输入变量",
|
||||
"type": "object",
|
||||
"obvious_hint": True,
|
||||
"items": {},
|
||||
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
||||
"invisible": True,
|
||||
},
|
||||
# "fastgpt_app_type": {
|
||||
# "description": "应用类型",
|
||||
# "type": "string",
|
||||
# "hint": "FastGPT 应用的应用类型。",
|
||||
# "options": ["agent", "workflow", "plugin"],
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
"dashscope_app_type": {
|
||||
"description": "应用类型",
|
||||
"type": "string",
|
||||
@@ -1499,7 +1504,6 @@ CONFIG_METADATA_2 = {
|
||||
"dialog-workflow",
|
||||
"task-workflow",
|
||||
],
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"timeout": {
|
||||
"description": "超时时间",
|
||||
@@ -1509,26 +1513,22 @@ CONFIG_METADATA_2 = {
|
||||
"openai-tts-voice": {
|
||||
"description": "voice",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||
},
|
||||
"fishaudio-tts-character": {
|
||||
"description": "character",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
||||
},
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
|
||||
"hint": "模型提供商名字。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商种类",
|
||||
@@ -1543,53 +1543,27 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用",
|
||||
"type": "bool",
|
||||
"hint": "是否启用该模型。未启用的模型将不会被使用。",
|
||||
"hint": "是否启用。",
|
||||
},
|
||||
"key": {
|
||||
"description": "API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。",
|
||||
"hint": "提供商 API Key。",
|
||||
},
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
"description": "基座模型路径",
|
||||
"type": "string",
|
||||
"hint": "基座模型路径。",
|
||||
},
|
||||
"adapter_model_path": {
|
||||
"description": "Adapter 模型路径",
|
||||
"type": "string",
|
||||
"hint": "Adapter 模型路径。如 Lora",
|
||||
},
|
||||
"llmtuner_template": {
|
||||
"description": "template",
|
||||
"type": "string",
|
||||
"hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。",
|
||||
},
|
||||
"finetuning_type": {
|
||||
"description": "微调类型",
|
||||
"type": "string",
|
||||
"hint": "微调类型。如 `lora`",
|
||||
},
|
||||
"quantization_bit": {
|
||||
"description": "量化位数",
|
||||
"type": "int",
|
||||
"hint": "量化位数。如 4",
|
||||
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "文本生成模型",
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {
|
||||
"description": "模型名称",
|
||||
"type": "string",
|
||||
"hint": "大语言模型的名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat 等。",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "模型最大输出长度(tokens)",
|
||||
@@ -1636,7 +1610,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"separate_provider": {
|
||||
"description": "提供商会话隔离",
|
||||
@@ -1656,13 +1629,11 @@ CONFIG_METADATA_2 = {
|
||||
"web_search": {
|
||||
"description": "启用网页搜索",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
},
|
||||
"web_search_link": {
|
||||
"description": "网页搜索引用链接",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||
},
|
||||
"display_reasoning_text": {
|
||||
@@ -1673,13 +1644,11 @@ CONFIG_METADATA_2 = {
|
||||
"identifier": {
|
||||
"description": "启动识别群员",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"description": "启用日期时间系统提示",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||
},
|
||||
"default_personality": {
|
||||
@@ -1736,7 +1705,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "人格名称",
|
||||
"type": "string",
|
||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"prompt": {
|
||||
"description": "设定(系统提示词)",
|
||||
@@ -1748,14 +1716,12 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1767,7 +1733,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用语音转文本(STT)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID",
|
||||
@@ -1784,7 +1749,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用文本转语音(TTS)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID",
|
||||
@@ -1795,7 +1759,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用语音和文字双输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"use_file_service": {
|
||||
"description": "使用文件服务提供 TTS 语音文件",
|
||||
@@ -1811,25 +1774,21 @@ CONFIG_METADATA_2 = {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious_hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "群聊图像转述(需模型支持)",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"description": "图像转述提供商 ID",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
@@ -1843,14 +1802,12 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用主动回复",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
|
||||
},
|
||||
"whitelist": {
|
||||
"description": "主动回复白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
|
||||
},
|
||||
"method": {
|
||||
@@ -1862,13 +1819,11 @@ CONFIG_METADATA_2 = {
|
||||
"possibility_reply": {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"obvious_hint": True,
|
||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
@@ -1884,7 +1839,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "机器人唤醒前缀",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
||||
},
|
||||
"t2i": {
|
||||
@@ -1911,13 +1865,11 @@ CONFIG_METADATA_2 = {
|
||||
"timezone": {
|
||||
"description": "时区",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||
},
|
||||
"callback_api_base": {
|
||||
"description": "对外可达的回调接口地址",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
||||
},
|
||||
"log_level": {
|
||||
@@ -1965,90 +1917,3 @@ DEFAULT_VALUE_MAP = {
|
||||
"list": [],
|
||||
"object": {},
|
||||
}
|
||||
|
||||
|
||||
# "project_atri": {
|
||||
# "description": "Project ATRI 配置",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# "long_term_memory": {
|
||||
# "description": "长期记忆",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# "summary_threshold_cnt": {
|
||||
# "description": "摘要阈值",
|
||||
# "type": "int",
|
||||
# "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。",
|
||||
# },
|
||||
# "embedding_provider_id": {
|
||||
# "description": "Embedding provider ID",
|
||||
# "type": "string",
|
||||
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "summarize_provider_id": {
|
||||
# "description": "Summary provider ID",
|
||||
# "type": "string",
|
||||
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
# "active_message": {
|
||||
# "description": "主动消息",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# },
|
||||
# },
|
||||
# "vision": {
|
||||
# "description": "视觉理解",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# "provider_id_or_ofa_model_path": {
|
||||
# "description": "提供商 ID 或 OFA 模型路径",
|
||||
# "type": "string",
|
||||
# "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。",
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
# "split_response": {
|
||||
# "description": "是否分割回复",
|
||||
# "type": "bool",
|
||||
# "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。",
|
||||
# },
|
||||
# "persona": {
|
||||
# "description": "人格",
|
||||
# "type": "string",
|
||||
# "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "chat_provider_id": {
|
||||
# "description": "Chat provider ID",
|
||||
# "type": "string",
|
||||
# "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "chat_base_model_path": {
|
||||
# "description": "用于聊天的基座模型路径",
|
||||
# "type": "string",
|
||||
# "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "chat_adapter_model_path": {
|
||||
# "description": "用于聊天的 Lora 模型路径",
|
||||
# "type": "string",
|
||||
# "hint": "Lora 模型路径。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "quantization_bit": {
|
||||
# "description": "量化位数",
|
||||
# "type": "int",
|
||||
# "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
|
||||
@@ -96,8 +96,6 @@ class LogBroker:
|
||||
Queue: 订阅者的队列, 可用于接收日志消息
|
||||
"""
|
||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||
for log in self.log_cache:
|
||||
q.put_nowait(log)
|
||||
self.subscribers.append(q)
|
||||
return q
|
||||
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
EventResultType,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
@@ -29,6 +31,7 @@ STAGES_ORDER = [
|
||||
__all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
|
||||
@@ -2,29 +2,30 @@
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import copy
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from ..stage import Stage
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -72,6 +73,12 @@ class LLMRequestSubStage(Stage):
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
@@ -166,6 +173,9 @@ class LLMRequestSubStage(Stage):
|
||||
event=event,
|
||||
pipeline_ctx=self.ctx,
|
||||
)
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
|
||||
)
|
||||
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||
|
||||
async def requesting():
|
||||
@@ -221,7 +231,7 @@ class LLMRequestSubStage(Stage):
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -177,6 +178,8 @@ class RespondStage(Stage):
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
@@ -185,7 +188,6 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
|
||||
@@ -3,11 +3,12 @@ import time
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
@@ -176,10 +177,12 @@ class ResultDecorateStage(Stage):
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
|
||||
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
event.stop_event()
|
||||
@@ -1,13 +1,16 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -166,6 +169,11 @@ class WakingCheckStage(Stage):
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
event, activated_handlers
|
||||
)
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
||||
|
||||
|
||||
@@ -227,7 +227,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp, gewechat。
|
||||
Fallback仅支持 aiocqhttp。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
@@ -419,7 +419,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
适配情况:
|
||||
|
||||
- gewechat
|
||||
- aiocqhttp(OneBotv11)
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -58,10 +58,6 @@ class PlatformManager:
|
||||
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
|
||||
@@ -272,8 +272,14 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
|
||||
reply_event_data["post_type"] = "message"
|
||||
new_event = Event.from_payload(reply_event_data)
|
||||
if not new_event:
|
||||
logger.error(
|
||||
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
|
||||
)
|
||||
continue
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
new_event, get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
|
||||
@@ -1,812 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SimpleGewechatClient:
|
||||
"""针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
nickname: str,
|
||||
host: str,
|
||||
port: int,
|
||||
event_queue: asyncio.Queue,
|
||||
):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith("/"):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
|
||||
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_token>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
self.userrealnames = {}
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self.staged_files = {}
|
||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob["data"]
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {"X-GEWE-TOKEN": self.token}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
if "TypeName" in data:
|
||||
type_name = data["TypeName"]
|
||||
elif "type_name" in data:
|
||||
type_name = data["type_name"]
|
||||
else:
|
||||
raise Exception("无法识别的消息类型")
|
||||
|
||||
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||
if type_name == "ModContacts":
|
||||
logger.info("gewechat下发:ModContacts消息通知。")
|
||||
return
|
||||
if type_name == "DelContacts":
|
||||
logger.info("gewechat下发:DelContacts消息通知。")
|
||||
return
|
||||
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
d = None
|
||||
if "Data" in data:
|
||||
d = data["Data"]
|
||||
elif "data" in data:
|
||||
d = data["data"]
|
||||
|
||||
if not d:
|
||||
logger.warning(f"消息不含 data 字段: {data}")
|
||||
return
|
||||
|
||||
if "CreateTime" in d:
|
||||
# 得到系统 UTF+8 的 ts
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
ts = datetime.datetime.now(tz).timestamp()
|
||||
create_time = d["CreateTime"]
|
||||
if create_time < ts - 30:
|
||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
from_user_name = d["FromUserName"]["string"] # 消息来源
|
||||
d["to_wxid"] = from_user_name # 用于发信息
|
||||
|
||||
abm.message_id = str(d.get("MsgId"))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data["Wxid"] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
at_wxids = []
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
at_wxids = re.findall(
|
||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||
msg_source,
|
||||
)
|
||||
|
||||
abm.group_id = from_user_name
|
||||
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
):
|
||||
at_me = True
|
||||
if "在群聊中@了你" in d.get("PushContent", ""):
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||
# if user_id == abm.self_id:
|
||||
# logger.info("忽略自己发送的消息")
|
||||
# return None
|
||||
|
||||
abm.message = []
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
if abm.group_id:
|
||||
if (
|
||||
abm.group_id not in self.userrealnames
|
||||
or user_id not in self.userrealnames[abm.group_id]
|
||||
):
|
||||
# 获取群成员列表,并且缓存
|
||||
if abm.group_id not in self.userrealnames:
|
||||
self.userrealnames[abm.group_id] = {}
|
||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||
if member_list and "memberList" in member_list:
|
||||
for member in member_list["memberList"]:
|
||||
self.userrealnames[abm.group_id][member["wxid"]] = member[
|
||||
"nickName"
|
||||
]
|
||||
if user_id in self.userrealnames[abm.group_id]:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
try:
|
||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||
user_real_name = info["nickName"]
|
||||
except Exception as e:
|
||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||
user_real_name = user_id
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||
for wxid in at_wxids:
|
||||
# 群聊里 At 其他人的列表
|
||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||
abm.message.append(At(qq=wxid, name=_username))
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
|
||||
if user_id == "weixin":
|
||||
# 忽略微信团队消息
|
||||
return
|
||||
|
||||
# 不同消息类型
|
||||
match d["MsgType"]:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid, content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||
case 37: # 好友申请
|
||||
logger.info("消息类型(37):好友申请")
|
||||
case 42: # 名片
|
||||
logger.info("消息类型(42):名片")
|
||||
case 43: # 视频
|
||||
video = Video(file="", cover=content)
|
||||
abm.message.append(video)
|
||||
case 47: # emoji
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
emoji = data_parser.parse_emoji()
|
||||
abm.message.append(emoji)
|
||||
case 48: # 地理位置
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
segments = data_parser.parse_mutil_49()
|
||||
if segments:
|
||||
abm.message.extend(segments)
|
||||
for seg in segments:
|
||||
if isinstance(seg, Plain):
|
||||
abm.message_str += seg.text
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||
logger.info(
|
||||
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||
)
|
||||
|
||||
case _:
|
||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||
abm.raw_message = d
|
||||
|
||||
logger.debug(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def _callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get("testMsg", None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = None
|
||||
try:
|
||||
abm = await self._convert(data)
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。"
|
||||
)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _register_file(self, file_path: str) -> str:
|
||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||
|
||||
Args:
|
||||
file_path (str): 文件路径。
|
||||
Returns:
|
||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||
"""
|
||||
async with self.lock:
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
self.staged_files[file_token] = file_path
|
||||
return file_token
|
||||
|
||||
async def _handle_file(self, file_token):
|
||||
async with self.lock:
|
||||
if file_token not in self.staged_files:
|
||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||
return quart.abort(404)
|
||||
if not os.path.exists(self.staged_files[file_token]):
|
||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||
return quart.abort(404)
|
||||
file_path = self.staged_files[file_token]
|
||||
self.staged_files.pop(file_token, None)
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={"token": self.token, "callbackUrl": self.callback_url},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob["ret"] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(
|
||||
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
|
||||
)
|
||||
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host="0.0.0.0",
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
"""检查 APPID 对应的设备是否在线。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={"appId": appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def logout(self):
|
||||
"""登出 gewechat。"""
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={"appId": self.appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(
|
||||
self.base_url, self.download_base_url, self.token
|
||||
)
|
||||
|
||||
if self.appid:
|
||||
try:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态失败: {e}")
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob["ret"] != 200:
|
||||
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||
if "设备不存在" in error_msg:
|
||||
logger.error(
|
||||
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||
)
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
return await self.login()
|
||||
else:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob["data"]["qrData"]
|
||||
qr_uuid = json_blob["data"]["uuid"]
|
||||
appid = json_blob["data"]["appId"]
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(
|
||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({"uuid": qr_uuid, "appId": appid})
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
|
||||
ret = json_blob["ret"]
|
||||
msg = ""
|
||||
if json_blob["data"] and "msg" in json_blob["data"]:
|
||||
msg = json_blob["data"]["msg"]
|
||||
if ret == 500 and "安全验证码" in msg:
|
||||
logger.warning(
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
if "status" in json_blob["data"]:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||
"""
|
||||
|
||||
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||
"""获取群成员列表。
|
||||
|
||||
Args:
|
||||
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||
|
||||
Returns:
|
||||
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||
"""
|
||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
"""发送纯文本消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
if ats:
|
||||
payload["ats"] = ats
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
"""发送图片消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||
"""发送emoji消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"emojiMd5": emoji_md5,
|
||||
"emojiSize": emoji_size,
|
||||
}
|
||||
|
||||
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||
try:
|
||||
if emoji_md5 != "" and emoji_size != "":
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postEmoji",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(
|
||||
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||
)
|
||||
else:
|
||||
await self.post_image(to_wxid, cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
async def post_video(
|
||||
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"videoUrl": video_url,
|
||||
"thumbUrl": thumb_url,
|
||||
"videoDuration": video_duration,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送视频结果: {json_blob}")
|
||||
|
||||
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||
"""转发视频
|
||||
|
||||
Args:
|
||||
to_wxid (str): 发送给谁
|
||||
cnd_xml (str): 视频消息的cdn信息
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"xml": cnd_xml,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/forwardVideo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"转发视频结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
"""发送语音信息
|
||||
|
||||
Args:
|
||||
voice_url (str): 语音文件的网络链接
|
||||
voice_duration (int): 语音时长,毫秒
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration,
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||
|
||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||
"""发送文件
|
||||
|
||||
Args:
|
||||
to_wxid (string): 微信ID
|
||||
file_url (str): 文件的网络链接
|
||||
file_name (str): 文件名
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"fileUrl": file_url,
|
||||
"fileName": file_name,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送文件结果: {json_blob}")
|
||||
|
||||
async def add_friend(self, v3: str, v4: str, content: str):
|
||||
"""申请添加好友"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"scene": 3,
|
||||
"content": content,
|
||||
"v4": v4,
|
||||
"v3": v3,
|
||||
"option": 2,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/addContacts",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group_member(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def accept_group_invite(self, url: str):
|
||||
"""同意进群"""
|
||||
payload = {"appId": self.appid, "url": url}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/agreeJoinRoom",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def add_group_member_to_friend(
|
||||
self, group_id: str, to_wxid: str, content: str
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
"content": content,
|
||||
"memberWxid": to_wxid,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_user_or_group_info(self, *ids):
|
||||
"""
|
||||
获取用户或群组信息。
|
||||
|
||||
:param ids: 可变数量的 wxid 参数
|
||||
"""
|
||||
|
||||
wxids_str = list(ids)
|
||||
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/getDetailInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_contacts_list(self):
|
||||
"""
|
||||
获取通讯录列表
|
||||
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||
"""
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/fetchContactsList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||
return json_blob
|
||||
@@ -1,55 +0,0 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
class GeweDownloader:
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
"""返回一个可下载的 URL"""
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {"appId": appid, "xml": xml, "type": choice}
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadImage", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
if "fileUrl" in json_blob["data"]:
|
||||
return self.download_base_url + json_blob["data"]["fileUrl"]
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
|
||||
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||
"""下载emoji"""
|
||||
try:
|
||||
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||
|
||||
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadEmojiMd5", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
return json_blob
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download emoji: {e}")
|
||||
@@ -1,264 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
Record,
|
||||
At,
|
||||
File,
|
||||
Video,
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, "rb") as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
elif n_frames == 0:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(
|
||||
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
||||
):
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
# 检查@
|
||||
ats = []
|
||||
ats_names = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, At):
|
||||
ats.append(comp.qq)
|
||||
ats_names.append(comp.name)
|
||||
has_at = False
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text = comp.text
|
||||
payload = {
|
||||
"to_wxid": to_wxid,
|
||||
"content": text,
|
||||
}
|
||||
if not has_at and ats:
|
||||
ats = f"{','.join(ats)}"
|
||||
ats_names = f"@{' @'.join(ats_names)}"
|
||||
text = f"{ats_names} {text}"
|
||||
payload["content"] = text
|
||||
payload["ats"] = ats
|
||||
has_at = True
|
||||
await client.post_text(**payload)
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||
token = await client._register_file(img_path)
|
||||
img_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
if comp.cover != "":
|
||||
await client.forward_video(to_wxid, comp.cover)
|
||||
else:
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
raise ModuleNotFoundError(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
|
||||
video_token = await client._register_file(video_path)
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||
ff.options(command)
|
||||
thumb_token = await client._register_file(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
|
||||
# 创建 FFprobe 实例
|
||||
ffprobe = FFprobe(video_url)
|
||||
# 获取时长字符串
|
||||
duration_str = ffprobe.duration
|
||||
# 处理时长字符串
|
||||
video_duration = float(duration_str.replace(":", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
# 发送视频
|
||||
await client.post_video(
|
||||
to_wxid, video_callback_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时缩略图文件
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
token = await client._register_file(silk_path)
|
||||
record_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
file_path = comp.file
|
||||
file_name = comp.name
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
token = await client._register_file(file_path)
|
||||
file_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_name)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||
await super().send(message)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
# 确定有效的 group_id
|
||||
if group_id is None:
|
||||
group_id = self.get_group_id()
|
||||
|
||||
if not group_id:
|
||||
return None
|
||||
|
||||
res = await self.client.get_group(group_id)
|
||||
data: dict = res["data"]
|
||||
|
||||
if not data["chatroomId"]:
|
||||
return None
|
||||
|
||||
members = [
|
||||
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||
for member in data.get("memberList", [])
|
||||
]
|
||||
|
||||
return Group(
|
||||
group_id=data["chatroomId"],
|
||||
group_name=data.get("nickName"),
|
||||
group_avatar=data.get("smallHeadImgUrl"),
|
||||
group_owner=data.get("chatRoomOwner"),
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
@@ -1,103 +0,0 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||
self.client = None
|
||||
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config["base_url"],
|
||||
self.config["nickname"],
|
||||
self.config["host"],
|
||||
self.config["port"],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
session_id = session.session_id
|
||||
if "#" in session_id:
|
||||
# unique session
|
||||
to_wxid = session_id.split("#")[1]
|
||||
else:
|
||||
to_wxid = session_id
|
||||
|
||||
await GewechatPlatformEvent.send_with_client(
|
||||
message_chain, to_wxid, self.client
|
||||
)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="gewechat",
|
||||
description="基于 gewechat 的 Wechat 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
self.client.shutdown_event.set()
|
||||
try:
|
||||
await self.client.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
await self.client.start_polling()
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss["unique_session"]:
|
||||
message.session_id = message.sender.user_id + "#" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self) -> SimpleGewechatClient:
|
||||
return self.client
|
||||
@@ -1,110 +0,0 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Reply,
|
||||
Plain,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
def __init__(self, data, is_private_chat):
|
||||
self.data = data
|
||||
self.is_private_chat = is_private_chat
|
||||
|
||||
def _format_to_xml(self):
|
||||
return eT.fromstring(self.data)
|
||||
|
||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||
if appmsg_type is None:
|
||||
return
|
||||
|
||||
match appmsg_type.text:
|
||||
case "57":
|
||||
return self.parse_reply()
|
||||
|
||||
def parse_emoji(self) -> Emoji | None:
|
||||
try:
|
||||
emoji_element = self._format_to_xml().find(".//emoji")
|
||||
# 提取 md5 和 len 属性
|
||||
if emoji_element is not None:
|
||||
md5_value = emoji_element.get("md5")
|
||||
emoji_size = emoji_element.get("len")
|
||||
cdnurl = emoji_element.get("cdnurl")
|
||||
|
||||
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||
|
||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||
"""解析引用消息
|
||||
|
||||
Returns:
|
||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||
"""
|
||||
try:
|
||||
replied_id = -1
|
||||
replied_uid = 0
|
||||
replied_nickname = ""
|
||||
replied_content = "" # 被引用者说的内容
|
||||
content = "" # 引用者说的内容
|
||||
|
||||
root = self._format_to_xml()
|
||||
refermsg = root.find(".//refermsg")
|
||||
if refermsg is not None:
|
||||
# 被引用的信息
|
||||
svrid = refermsg.find("svrid")
|
||||
fromusr = refermsg.find("fromusr")
|
||||
displayname = refermsg.find("displayname")
|
||||
refermsg_content = refermsg.find("content")
|
||||
if svrid is not None:
|
||||
replied_id = svrid.text
|
||||
if fromusr is not None:
|
||||
replied_uid = fromusr.text
|
||||
if displayname is not None:
|
||||
replied_nickname = displayname.text
|
||||
if refermsg_content is not None:
|
||||
# 处理引用嵌套,包括嵌套公众号消息
|
||||
if refermsg_content.text.startswith(
|
||||
"<msg>"
|
||||
) or refermsg_content.text.startswith("<?xml"):
|
||||
try:
|
||||
logger.debug("gewechat: Reference message is nested")
|
||||
refer_root = eT.fromstring(refermsg_content.text)
|
||||
img = refer_root.find("img")
|
||||
if img is not None:
|
||||
replied_content = "[图片]"
|
||||
else:
|
||||
app_msg = refer_root.find("appmsg")
|
||||
refermsg_content_title = app_msg.find("title")
|
||||
logger.debug(
|
||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||
)
|
||||
replied_content = refermsg_content_title.text
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: nested failed, {e}")
|
||||
# 处理异常情况
|
||||
replied_content = refermsg_content.text
|
||||
else:
|
||||
replied_content = refermsg_content.text
|
||||
|
||||
# 提取引用者说的内容
|
||||
title = root.find(".//appmsg/title")
|
||||
if title is not None:
|
||||
content = title.text
|
||||
|
||||
reply_seg = Reply(
|
||||
id=replied_id,
|
||||
chain=[Plain(replied_content)],
|
||||
sender_id=replied_uid,
|
||||
sender_nickname=replied_nickname,
|
||||
message_str=replied_content,
|
||||
)
|
||||
plain_seg = Plain(content)
|
||||
return [reply_seg, plain_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||
@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
@@ -80,12 +146,10 @@ class FuncTool:
|
||||
if not self.mcp_client or not self.mcp_client.session:
|
||||
raise Exception(f"MCP client for {self.name} is not available")
|
||||
# 使用name属性而不是额外的mcp_tool_name
|
||||
if ":" in self.name:
|
||||
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||
actual_tool_name = self.name.split(":")[-1]
|
||||
actual_tool_name = (
|
||||
self.name.split(":")[-1] if ":" in self.name else self.name
|
||||
)
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
return await self.mcp_client.session.call_tool(self.name, args)
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
|
||||
@@ -100,6 +164,7 @@ class MCPClient:
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
@@ -112,17 +177,19 @@ class MCPClient:
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = mcp_server_config.copy()
|
||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||
cfg = cfg["mcpServers"][key_0]
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
@@ -130,11 +197,18 @@ class MCPClient:
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
@@ -148,11 +222,19 @@ class MCPClient:
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -172,7 +254,7 @@ class MCPClient:
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
@@ -180,19 +262,18 @@ class MCPClient:
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
|
||||
|
||||
class FuncCall:
|
||||
@@ -201,8 +282,6 @@ class FuncCall:
|
||||
"""内部加载的 func tools"""
|
||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_service_queue = asyncio.Queue()
|
||||
"""用于外部控制 MCP 服务的启停"""
|
||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
@@ -258,7 +337,7 @@ class FuncCall:
|
||||
return f
|
||||
return None
|
||||
|
||||
async def _init_mcp_clients(self) -> None:
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
{
|
||||
@@ -300,72 +379,32 @@ class FuncCall:
|
||||
)
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
async def mcp_service_selector(self):
|
||||
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||
|
||||
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||
|
||||
{"type": "init"} 初始化所有MCP客户端
|
||||
|
||||
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||
|
||||
{"type": "terminate"} 终止所有MCP客户端
|
||||
|
||||
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||
"""
|
||||
while True:
|
||||
data = await self.mcp_service_queue.get()
|
||||
if data["type"] == "init":
|
||||
if "name" in data:
|
||||
event = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(
|
||||
data["name"], data["cfg"], event
|
||||
)
|
||||
)
|
||||
self.mcp_client_event[data["name"]] = event
|
||||
else:
|
||||
await self._init_mcp_clients()
|
||||
elif data["type"] == "terminate":
|
||||
if "name" in data:
|
||||
# await self._terminate_mcp_client(data["name"])
|
||||
if data["name"] in self.mcp_client_event:
|
||||
self.mcp_client_event[data["name"]].set()
|
||||
self.mcp_client_event.pop(data["name"], None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (
|
||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
for name in self.mcp_client_dict.keys():
|
||||
# await self._terminate_mcp_client(name)
|
||||
# self.mcp_client_event[name].set()
|
||||
if name in self.mcp_client_event:
|
||||
self.mcp_client_event[name].set()
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self, name: str, cfg: dict, event: asyncio.Event
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
await self._terminate_mcp_client(name)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
finally:
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
"""初始化单个MCP客户端"""
|
||||
try:
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
@@ -375,6 +414,7 @@ class FuncCall:
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
@@ -397,16 +437,6 @@ class FuncCall:
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -414,9 +444,9 @@ class FuncCall:
|
||||
try:
|
||||
# 关闭MCP连接
|
||||
await self.mcp_client_dict[name].cleanup()
|
||||
del self.mcp_client_dict[name]
|
||||
self.mcp_client_dict.pop(name)
|
||||
except Exception as e:
|
||||
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
@@ -425,6 +455,103 @@ class FuncCall:
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
if "url" in config:
|
||||
success, error_msg = await _quick_test_mcp_connection(config)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
try:
|
||||
logger.debug(f"testing MCP server connection with config: {config}")
|
||||
await mcp_client.connect_to_server(config, "test")
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
finally:
|
||||
logger.debug("Cleaning up MCP client after testing connection.")
|
||||
await mcp_client.cleanup()
|
||||
return tool_names
|
||||
|
||||
async def enable_mcp_server(
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
Raises:
|
||||
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str | None = None, timeout: float = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server by its name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout (int): Timeout.
|
||||
"""
|
||||
if name:
|
||||
if name not in self.mcp_client_event:
|
||||
return
|
||||
client = self.mcp_client_dict.get(name)
|
||||
self.mcp_client_event[name].set()
|
||||
if not client:
|
||||
return
|
||||
client_running_event = client.running_event
|
||||
try:
|
||||
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if f.origin != "mcp" or f.mcp_server_name != name
|
||||
]
|
||||
else:
|
||||
running_events = [
|
||||
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||
]
|
||||
for key, event in self.mcp_client_event.items():
|
||||
event.set()
|
||||
# waiting for all clients to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.clear()
|
||||
self.mcp_client_dict.clear()
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
@@ -629,8 +756,3 @@ class FuncCall:
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
async def terminate(self):
|
||||
for name in self.mcp_client_dict.keys():
|
||||
await self._terminate_mcp_client(name)
|
||||
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
||||
|
||||
@@ -7,7 +7,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map: dict[str, Provider] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
@@ -169,10 +169,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
||||
)
|
||||
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
@@ -433,5 +430,7 @@ class ProviderManager:
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
await provider_inst.terminate() # type: ignore
|
||||
# 清理 MCP Client 连接
|
||||
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
||||
try:
|
||||
await self.llm_tools.disable_mcp_server()
|
||||
except Exception:
|
||||
logger.error("Error while disabling MCP servers", exc_info=True)
|
||||
|
||||
@@ -2,7 +2,8 @@ import abc
|
||||
from typing import List
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -22,6 +23,7 @@ class ProviderMeta:
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
provider_type: ProviderType
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -40,10 +42,14 @@ class AbstractProvider(abc.ABC):
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
"""获取 Provider 的元数据"""
|
||||
provider_type_name = self.provider_config["type"]
|
||||
meta_data = provider_cls_map.get(provider_type_name)
|
||||
provider_type = meta_data.provider_type if meta_data else None
|
||||
return ProviderMeta(
|
||||
id=self.provider_config["id"],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config["type"],
|
||||
type=provider_type_name,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ class ProviderDify(Provider):
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
|
||||
@@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
|
||||
@@ -187,6 +187,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
func_name_ls = []
|
||||
tool_call_ids = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
# workaround for #1454
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core.star.star_tools import StarTools
|
||||
class Star(CommandParserMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
def __init__(self, context: Context, config: dict | None = None):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
@@ -41,9 +41,17 @@ class Star(CommandParserMixin):
|
||||
tmpl, data, return_url=return_url, options=options
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""当插件被激活时会调用这个方法"""
|
||||
pass
|
||||
|
||||
async def terminate(self):
|
||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||
|
||||
@@ -2,7 +2,7 @@ from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -141,6 +141,10 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Union
|
||||
class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
GEWECHAT = enum.auto()
|
||||
TELEGRAM = enum.auto()
|
||||
WECOM = enum.auto()
|
||||
LARK = enum.auto()
|
||||
@@ -22,7 +21,6 @@ class PlatformAdapterType(enum.Flag):
|
||||
ALL = (
|
||||
AIOCQHTTP
|
||||
| QQOFFICIAL
|
||||
| GEWECHAT
|
||||
| TELEGRAM
|
||||
| WECOM
|
||||
| LARK
|
||||
@@ -39,7 +37,6 @@ class PlatformAdapterType(enum.Flag):
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"gewechat": PlatformAdapterType.GEWECHAT,
|
||||
"telegram": PlatformAdapterType.TELEGRAM,
|
||||
"wecom": PlatformAdapterType.WECOM,
|
||||
"lark": PlatformAdapterType.LARK,
|
||||
|
||||
293
astrbot/core/star/session_llm_manager.py
Normal file
293
astrbot/core/star/session_llm_manager.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionServiceManager:
|
||||
"""管理会话级别的服务启停状态,包括LLM和TTS"""
|
||||
|
||||
# =============================================================================
|
||||
# LLM 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_services.get("llm_enabled")
|
||||
if llm_enabled is not None:
|
||||
return llm_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置LLM状态
|
||||
session_config[session_id]["llm_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# TTS 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查TTS是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
tts_enabled = session_services.get("tts_enabled")
|
||||
if tts_enabled is not None:
|
||||
return tts_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置TTS在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置TTS状态
|
||||
session_config[session_id]["tts_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理TTS请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话整体启停相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_session_enabled(session_id: str) -> bool:
|
||||
"""检查会话是否整体启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
session_enabled = session_services.get("session_enabled")
|
||||
if session_enabled is not None:
|
||||
return session_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_session_status(session_id: str, enabled: bool) -> None:
|
||||
"""设置会话的整体启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置会话整体状态
|
||||
session_config[session_id]["session_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理会话请求(会话整体启停检查)
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_session_enabled(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话命名相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
||||
"""设置会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置自定义名称
|
||||
if custom_name and custom_name.strip():
|
||||
session_config[session_id]["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config[session_id].pop("custom_name", None)
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_display_name(session_id: str) -> str:
|
||||
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 显示名称
|
||||
"""
|
||||
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||
if custom_name:
|
||||
return custom_name
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
# =============================================================================
|
||||
# 通用配置方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_service_config(session_id: str) -> Dict[str, bool]:
|
||||
"""获取指定会话的服务配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
return session_config.get(
|
||||
session_id,
|
||||
{
|
||||
"session_enabled": True, # 默认启用
|
||||
"llm_enabled": True, # 默认启用
|
||||
"tts_enabled": True, # 默认启用
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
|
||||
"""获取所有会话的服务配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, bool]]: 所有会话的服务配置
|
||||
"""
|
||||
return sp.get("session_service_config", {}) or {}
|
||||
142
astrbot/core/star/session_plugin_manager.py
Normal file
142
astrbot/core/star/session_plugin_manager.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
会话插件管理器 - 负责管理每个会话的插件启停状态
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict, List
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionPluginManager:
|
||||
"""管理会话级别的插件启停状态"""
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||
"""检查插件是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
# 如果插件在禁用列表中,返回False
|
||||
if plugin_name in disabled_plugins:
|
||||
return False
|
||||
|
||||
# 如果插件在启用列表中,返回True
|
||||
if plugin_name in enabled_plugins:
|
||||
return True
|
||||
|
||||
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_plugin_status_for_session(
|
||||
session_id: str, plugin_name: str, enabled: bool
|
||||
) -> None:
|
||||
"""设置插件在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": [],
|
||||
}
|
||||
|
||||
session_config = session_plugin_config[session_id]
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
if enabled:
|
||||
# 启用插件
|
||||
if plugin_name in disabled_plugins:
|
||||
disabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in enabled_plugins:
|
||||
enabled_plugins.append(plugin_name)
|
||||
else:
|
||||
# 禁用插件
|
||||
if plugin_name in enabled_plugins:
|
||||
enabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in disabled_plugins:
|
||||
disabled_plugins.append(plugin_name)
|
||||
|
||||
# 保存配置
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put("session_plugin_config", session_plugin_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]:
|
||||
"""获取指定会话的插件配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
"""
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
return session_plugin_config.get(
|
||||
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
handlers: 原始处理器列表
|
||||
|
||||
Returns:
|
||||
List: 过滤后的处理器列表
|
||||
"""
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
session_id = event.unified_msg_origin
|
||||
filtered_handlers = []
|
||||
|
||||
for handler in handlers:
|
||||
# 获取处理器对应的插件
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not plugin:
|
||||
# 如果找不到插件元数据,允许执行(可能是系统插件)
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
# 跳过保留插件(系统插件)
|
||||
if plugin.reserved:
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin.name
|
||||
):
|
||||
filtered_handlers.append(handler)
|
||||
else:
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}"
|
||||
)
|
||||
|
||||
return filtered_handlers
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
@@ -9,6 +10,9 @@ star_registry: list[StarMetadata] = []
|
||||
star_map: dict[str, StarMetadata] = {}
|
||||
"""key 是模块路径,__module__"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Star
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
@@ -29,12 +33,12 @@ class StarMetadata:
|
||||
repo: str | None = None
|
||||
"""插件仓库地址"""
|
||||
|
||||
star_cls_type: type | None = None
|
||||
star_cls_type: type[Star] | None = None
|
||||
"""插件的类对象的类型"""
|
||||
module_path: str | None = None
|
||||
"""插件的模块路径"""
|
||||
|
||||
star_cls: object | None = None
|
||||
star_cls: Star | None = None
|
||||
"""插件的类对象"""
|
||||
module: ModuleType | None = None
|
||||
"""插件的模块对象"""
|
||||
|
||||
@@ -163,7 +163,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str | None = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -187,7 +187,7 @@ class PluginManager:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata:
|
||||
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
|
||||
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
|
||||
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
|
||||
@@ -253,8 +253,8 @@ class PluginManager:
|
||||
|
||||
def _purge_modules(
|
||||
self,
|
||||
module_patterns: list[str] = None,
|
||||
root_dir_name: str = None,
|
||||
module_patterns: list[str] | None = None,
|
||||
root_dir_name: str | None = None,
|
||||
is_reserved: bool = False,
|
||||
):
|
||||
"""从 sys.modules 中移除指定的模块
|
||||
@@ -314,7 +314,7 @@ class PluginManager:
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
|
||||
if smd.name and smd.module_path:
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
|
||||
star_handlers_registry.clear()
|
||||
@@ -331,7 +331,7 @@ class PluginManager:
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
|
||||
if smd.name:
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
result = await self.load(specified_module_path)
|
||||
@@ -460,8 +460,7 @@ class PluginManager:
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
# metadata.config = plugin_config
|
||||
if plugin_config and metadata.star_cls_type:
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -470,7 +469,7 @@ class PluginManager:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
else:
|
||||
elif metadata.star_cls_type:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
@@ -487,6 +486,10 @@ class PluginManager:
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = (
|
||||
star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -495,7 +498,8 @@ class PluginManager:
|
||||
)
|
||||
for handler in related_handlers:
|
||||
handler.handler = functools.partial(
|
||||
handler.handler, metadata.star_cls
|
||||
handler.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
@@ -505,7 +509,8 @@ class PluginManager:
|
||||
):
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(
|
||||
func_tool.handler, metadata.star_cls
|
||||
func_tool.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
@@ -532,13 +537,12 @@ class PluginManager:
|
||||
obj = getattr(module, classes[0])(
|
||||
context=self.context
|
||||
) # 实例化插件类
|
||||
else:
|
||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||
|
||||
metadata = None
|
||||
metadata = self._load_plugin_metadata(
|
||||
plugin_path=plugin_dir_path, plugin_obj=obj
|
||||
)
|
||||
if not metadata:
|
||||
raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
@@ -553,6 +557,10 @@ class PluginManager:
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
full_names = []
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||
metadata.module_path
|
||||
@@ -592,7 +600,7 @@ class PluginManager:
|
||||
metadata.star_handler_full_names = full_names
|
||||
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
@@ -734,6 +742,9 @@ class PluginManager:
|
||||
]:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
if plugin is None:
|
||||
return
|
||||
|
||||
self._purge_modules(
|
||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||
)
|
||||
@@ -795,6 +806,9 @@ class PluginManager:
|
||||
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
||||
return
|
||||
|
||||
if star_metadata.star_cls is None:
|
||||
return
|
||||
|
||||
if hasattr(star_metadata.star_cls, "__del__"):
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
|
||||
@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
|
||||
raise exc_info[1]
|
||||
|
||||
|
||||
def remove_dir(file_path) -> bool:
|
||||
def remove_dir(file_path: str) -> bool:
|
||||
if not os.path.exists(file_path):
|
||||
return True
|
||||
shutil.rmtree(file_path, onerror=on_error)
|
||||
|
||||
29
astrbot/core/utils/session_lock.py
Normal file
29
astrbot/core/utils/session_lock.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
class SessionLockManager:
|
||||
def __init__(self):
|
||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._lock_count: dict[str, int] = defaultdict(int)
|
||||
self._access_lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire_lock(self, session_id: str):
|
||||
async with self._access_lock:
|
||||
lock = self._locks[session_id]
|
||||
self._lock_count[session_id] += 1
|
||||
|
||||
try:
|
||||
async with lock:
|
||||
yield
|
||||
finally:
|
||||
async with self._access_lock:
|
||||
self._lock_count[session_id] -= 1
|
||||
if self._lock_count[session_id] == 0:
|
||||
self._locks.pop(session_id, None)
|
||||
self._lock_count.pop(session_id, None)
|
||||
|
||||
|
||||
session_lock_manager = SessionLockManager()
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
@@ -24,7 +26,7 @@ class SharedPreferences:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
|
||||
@@ -9,6 +9,7 @@ from .chat import ChatRoute
|
||||
from .tools import ToolsRoute # 导入新的ToolsRoute
|
||||
from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
from .session_management import SessionManagementRoute
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -23,4 +24,5 @@ __all__ = [
|
||||
"ToolsRoute",
|
||||
"ConversationRoute",
|
||||
"FileRoute",
|
||||
"SessionManagementRoute",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
@@ -187,15 +190,12 @@ class ConfigRoute(Route):
|
||||
"""辅助函数:测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
provider_name = provider.provider_config.get("id", "Unknown Provider")
|
||||
logger.debug(f"Got provider meta: {meta}")
|
||||
if not provider_name and meta:
|
||||
provider_name = meta.id
|
||||
elif not provider_name:
|
||||
provider_name = "Unknown Provider"
|
||||
provider_capability_type = meta.provider_type
|
||||
|
||||
status_info = {
|
||||
"id": getattr(meta, "id", "Unknown ID"),
|
||||
"model": getattr(meta, "model", "Unknown Model"),
|
||||
"type": getattr(meta, "type", "Unknown Type"),
|
||||
"type": provider_capability_type.value,
|
||||
"name": provider_name,
|
||||
"status": "unavailable", # 默认为不可用
|
||||
"error": None,
|
||||
@@ -203,13 +203,14 @@ class ConfigRoute(Route):
|
||||
logger.debug(
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
|
||||
)
|
||||
|
||||
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
@@ -232,30 +233,72 @@ class ConfigRoute(Route):
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
)
|
||||
else:
|
||||
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
)
|
||||
status_info["error"] = "Test call returned None, but expected an LLMResponse object."
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
|
||||
)
|
||||
status_info["error"] = "Connection timed out after 45 seconds during test call."
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.")
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}")
|
||||
logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}")
|
||||
|
||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
||||
try:
|
||||
# For embedding, we can call the get_embedding method with a short prompt.
|
||||
embedding_result = await provider.get_embedding("health_check")
|
||||
if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)):
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: {str(e)}"
|
||||
|
||||
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
||||
try:
|
||||
# For TTS, we can call the get_audio method with a short prompt.
|
||||
audio_result = await provider.get_audio("你好")
|
||||
if isinstance(audio_result, str) and audio_result:
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: {str(e)}"
|
||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
||||
try:
|
||||
logger.debug(f"Sending health check audio to provider: {status_info['name']}")
|
||||
sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav")
|
||||
if not os.path.exists(sample_audio_path):
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = "STT test failed: sample audio file not found."
|
||||
logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}")
|
||||
else:
|
||||
text_result = await provider.get_text(sample_audio_path)
|
||||
if isinstance(text_result, str) and text_result:
|
||||
status_info["status"] = "available"
|
||||
snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result
|
||||
logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'")
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}"
|
||||
logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: {str(e)}"
|
||||
else:
|
||||
logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}")
|
||||
status_info["status"] = "available"
|
||||
status_info["error"] = "This provider type is not tested and is assumed to be available."
|
||||
|
||||
return status_info
|
||||
|
||||
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
|
||||
@@ -263,7 +306,7 @@ class ConfigRoute(Route):
|
||||
# 记录更详细的traceback信息,但只在是严重错误时
|
||||
if status_code == 500:
|
||||
log_fn(traceback.format_exc())
|
||||
return Response().error(message, status_code=status_code).__dict__
|
||||
return Response().error(message).__dict__
|
||||
|
||||
async def check_one_provider_status(self):
|
||||
"""API: check a single LLM Provider's status by id"""
|
||||
@@ -273,14 +316,12 @@ class ConfigRoute(Route):
|
||||
|
||||
logger.info(f"API call: /config/provider/check_one id={provider_id}")
|
||||
try:
|
||||
all_providers = self.core_lifecycle.star_context.get_all_providers()
|
||||
# replace manual loop with next(filter(...))
|
||||
target = next(
|
||||
(p for p in all_providers if p.provider_config.get("id") == provider_id),
|
||||
None
|
||||
)
|
||||
prov_mgr = self.core_lifecycle.provider_manager
|
||||
target = prov_mgr.inst_map.get(provider_id)
|
||||
|
||||
if not target:
|
||||
return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning)
|
||||
logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.")
|
||||
return Response().error(f"Provider with id '{provider_id}' not found").__dict__
|
||||
|
||||
result = await self._test_single_provider(target)
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
from quart import make_response
|
||||
from astrbot.core import logger, LogBroker
|
||||
from .route import Route, RouteContext
|
||||
from .route import Route, RouteContext, Response
|
||||
|
||||
|
||||
class LogRoute(Route):
|
||||
@@ -10,6 +10,7 @@ class LogRoute(Route):
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||
self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"])
|
||||
|
||||
async def log(self):
|
||||
async def stream():
|
||||
@@ -23,7 +24,6 @@ class LogRoute(Route):
|
||||
**message, # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.07) # 控制发送频率,避免过快
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
@@ -43,3 +43,14 @@ class LogRoute(Route):
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
async def log_history(self):
|
||||
"""获取日志历史"""
|
||||
try:
|
||||
logs = list(self.log_broker.log_cache)
|
||||
return Response().ok(data={
|
||||
"logs": logs,
|
||||
}).__dict__
|
||||
except BaseException as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import traceback
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import ssl
|
||||
import certifi
|
||||
@@ -75,15 +77,33 @@ class PluginRoute(Route):
|
||||
|
||||
async def get_online_plugins(self):
|
||||
custom = request.args.get("custom_registry")
|
||||
force_refresh = request.args.get("force_refresh", "false").lower() == "true"
|
||||
|
||||
cache_file = "data/plugins.json"
|
||||
|
||||
if custom:
|
||||
urls = [custom]
|
||||
else:
|
||||
urls = ["https://api.soulter.top/astrbot/plugins"]
|
||||
urls = [
|
||||
"https://api.soulter.top/astrbot/plugins",
|
||||
"https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json",
|
||||
]
|
||||
|
||||
# 新增:创建 SSL 上下文,使用 certifi 提供的根证书
|
||||
# 如果不是强制刷新,先检查缓存是否有效
|
||||
cached_data = None
|
||||
if not force_refresh:
|
||||
# 先检查MD5是否匹配,如果匹配则使用缓存
|
||||
if await self._is_cache_valid(cache_file):
|
||||
cached_data = self._load_plugin_cache(cache_file)
|
||||
if cached_data:
|
||||
logger.debug("缓存MD5匹配,使用缓存的插件市场数据")
|
||||
return Response().ok(cached_data).__dict__
|
||||
|
||||
# 尝试获取远程数据
|
||||
remote_data = None
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
async with aiohttp.ClientSession(
|
||||
@@ -91,14 +111,123 @@ class PluginRoute(Route):
|
||||
) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
remote_data = await response.json()
|
||||
|
||||
# 检查远程数据是否为空
|
||||
if not remote_data or (
|
||||
isinstance(remote_data, dict) and len(remote_data) == 0
|
||||
):
|
||||
logger.warning(f"远程插件市场数据为空: {url}")
|
||||
continue # 继续尝试其他URL或使用缓存
|
||||
|
||||
logger.info("成功获取远程插件市场数据")
|
||||
# 获取最新的MD5并保存到缓存
|
||||
current_md5 = await self._get_remote_md5()
|
||||
self._save_plugin_cache(
|
||||
cache_file, remote_data, current_md5
|
||||
)
|
||||
return Response().ok(remote_data).__dict__
|
||||
else:
|
||||
logger.error(f"请求 {url} 失败,状态码:{response.status}")
|
||||
except Exception as e:
|
||||
logger.error(f"请求 {url} 失败,错误:{e}")
|
||||
|
||||
return Response().error("获取插件列表失败").__dict__
|
||||
# 如果远程获取失败,尝试使用缓存数据
|
||||
if not cached_data:
|
||||
cached_data = self._load_plugin_cache(cache_file)
|
||||
|
||||
if cached_data:
|
||||
logger.warning("远程插件市场数据获取失败,使用缓存数据")
|
||||
return Response().ok(cached_data, "使用缓存数据,可能不是最新版本").__dict__
|
||||
|
||||
return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__
|
||||
|
||||
async def _is_cache_valid(self, cache_file: str) -> bool:
|
||||
"""检查缓存是否有效(基于MD5)"""
|
||||
try:
|
||||
if not os.path.exists(cache_file):
|
||||
return False
|
||||
|
||||
# 加载缓存文件
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cache_data = json.load(f)
|
||||
|
||||
cached_md5 = cache_data.get("md5")
|
||||
if not cached_md5:
|
||||
logger.debug("缓存文件中没有MD5信息")
|
||||
return False
|
||||
|
||||
# 获取远程MD5
|
||||
remote_md5 = await self._get_remote_md5()
|
||||
if not remote_md5:
|
||||
logger.warning("无法获取远程MD5,将使用缓存")
|
||||
return True # 如果无法获取远程MD5,认为缓存有效
|
||||
|
||||
is_valid = cached_md5 == remote_md5
|
||||
logger.debug(
|
||||
f"插件数据MD5: 本地={cached_md5}, 远程={remote_md5}, 有效={is_valid}"
|
||||
)
|
||||
return is_valid
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检查缓存有效性失败: {e}")
|
||||
return False
|
||||
|
||||
async def _get_remote_md5(self) -> str:
|
||||
"""获取远程插件数据的MD5"""
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, connector=connector
|
||||
) as session:
|
||||
async with session.get(
|
||||
"https://api.soulter.top/astrbot/plugins-md5"
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("md5", "")
|
||||
else:
|
||||
logger.error(f"获取MD5失败,状态码:{response.status}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"获取远程MD5失败: {e}")
|
||||
return ""
|
||||
|
||||
def _load_plugin_cache(self, cache_file: str):
|
||||
"""加载本地缓存的插件市场数据"""
|
||||
try:
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cache_data = json.load(f)
|
||||
# 检查缓存是否有效
|
||||
if "data" in cache_data and "timestamp" in cache_data:
|
||||
logger.debug(
|
||||
f"加载缓存文件: {cache_file}, 缓存时间: {cache_data['timestamp']}"
|
||||
)
|
||||
return cache_data["data"]
|
||||
except Exception as e:
|
||||
logger.warning(f"加载插件市场缓存失败: {e}")
|
||||
return None
|
||||
|
||||
def _save_plugin_cache(self, cache_file: str, data, md5: str = None):
|
||||
"""保存插件市场数据到本地缓存"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
||||
|
||||
cache_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": data,
|
||||
"md5": md5 or "",
|
||||
}
|
||||
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(cache_data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"插件市场数据已缓存到: {cache_file}, MD5: {md5}")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存插件市场缓存失败: {e}")
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
|
||||
673
astrbot/dashboard/routes/session_management.py
Normal file
673
astrbot/dashboard/routes/session_management.py
Normal file
@@ -0,0 +1,673 @@
|
||||
import traceback
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class SessionManagementRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db_helper: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/session/list": ("GET", self.list_sessions),
|
||||
"/session/update_persona": ("POST", self.update_session_persona),
|
||||
"/session/update_provider": ("POST", self.update_session_provider),
|
||||
"/session/get_session_info": ("POST", self.get_session_info),
|
||||
"/session/plugins": ("GET", self.get_session_plugins),
|
||||
"/session/update_plugin": ("POST", self.update_session_plugin),
|
||||
"/session/update_llm": ("POST", self.update_session_llm),
|
||||
"/session/update_tts": ("POST", self.update_session_tts),
|
||||
"/session/update_name": ("POST", self.update_session_name),
|
||||
"/session/update_status": ("POST", self.update_session_status),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
# 获取会话对话映射
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
|
||||
# 获取会话提供商偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
|
||||
# 获取可用的 personas
|
||||
personas = self.core_lifecycle.star_context.provider_manager.personas
|
||||
|
||||
# 获取可用的 providers
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
sessions = []
|
||||
|
||||
# 构建会话信息
|
||||
for session_id, conversation_id in session_conversations.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"session_enabled": SessionServiceManager.is_session_enabled(
|
||||
session_id
|
||||
),
|
||||
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"tts_enabled": SessionServiceManager.is_tts_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"platform": session_id.split(":")[0]
|
||||
if ":" in session_id
|
||||
else "unknown",
|
||||
"message_type": session_id.split(":")[1]
|
||||
if session_id.count(":") >= 1
|
||||
else "unknown",
|
||||
"session_name": SessionServiceManager.get_session_display_name(
|
||||
session_id
|
||||
),
|
||||
"session_raw_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
}
|
||||
|
||||
# 获取对话信息
|
||||
conversation = self.db_helper.get_conversation_by_user_id(
|
||||
session_id, conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
# 查找 persona 名称
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = (
|
||||
default_stt_provider.meta().id
|
||||
)
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = (
|
||||
default_tts_provider.meta().id
|
||||
)
|
||||
|
||||
sessions.append(session_info)
|
||||
|
||||
# 获取可用的 personas 和 providers 列表
|
||||
available_personas = [
|
||||
{"name": p["name"], "prompt": p.get("prompt", "")} for p in personas
|
||||
]
|
||||
|
||||
available_chat_providers = []
|
||||
for provider in provider_manager.provider_insts:
|
||||
meta = provider.meta()
|
||||
available_chat_providers.append(
|
||||
{
|
||||
"id": meta.id,
|
||||
"name": meta.id,
|
||||
"model": meta.model,
|
||||
"type": meta.type,
|
||||
}
|
||||
)
|
||||
|
||||
available_stt_providers = []
|
||||
for provider in provider_manager.stt_provider_insts:
|
||||
meta = provider.meta()
|
||||
available_stt_providers.append(
|
||||
{
|
||||
"id": meta.id,
|
||||
"name": meta.id,
|
||||
"model": meta.model,
|
||||
"type": meta.type,
|
||||
}
|
||||
)
|
||||
|
||||
available_tts_providers = []
|
||||
for provider in provider_manager.tts_provider_insts:
|
||||
meta = provider.meta()
|
||||
available_tts_providers.append(
|
||||
{
|
||||
"id": meta.id,
|
||||
"name": meta.id,
|
||||
"model": meta.model,
|
||||
"type": meta.type,
|
||||
}
|
||||
)
|
||||
|
||||
result = {
|
||||
"sessions": sessions,
|
||||
"available_personas": available_personas,
|
||||
"available_chat_providers": available_chat_providers,
|
||||
"available_stt_providers": available_stt_providers,
|
||||
"available_tts_providers": available_tts_providers,
|
||||
}
|
||||
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话列表失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_persona(self):
|
||||
"""更新指定会话的 persona"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
persona_name = data.get("persona_name")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if persona_name is None:
|
||||
return Response().error("缺少必要参数: persona_name").__dict__
|
||||
|
||||
# 获取会话当前的对话 ID
|
||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
|
||||
if not conversation_id:
|
||||
# 如果没有对话,创建一个新的对话
|
||||
conversation_id = await conversation_manager.new_conversation(
|
||||
session_id
|
||||
)
|
||||
|
||||
# 更新 persona
|
||||
await conversation_manager.update_conversation_persona_id(
|
||||
session_id, persona_name
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_provider(self):
|
||||
"""更新指定会话的 provider"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
provider_id = data.get("provider_id")
|
||||
# "chat_completion", "speech_to_text", "text_to_speech"
|
||||
provider_type = data.get("provider_type")
|
||||
|
||||
if not session_id or not provider_id or not provider_type:
|
||||
return (
|
||||
Response()
|
||||
.error("缺少必要参数: session_id, provider_id, provider_type")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 转换 provider_type 字符串为枚举
|
||||
if provider_type == "chat_completion":
|
||||
provider_type_enum = ProviderType.CHAT_COMPLETION
|
||||
elif provider_type == "speech_to_text":
|
||||
provider_type_enum = ProviderType.SPEECH_TO_TEXT
|
||||
elif provider_type == "text_to_speech":
|
||||
provider_type_enum = ProviderType.TEXT_TO_SPEECH
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"不支持的 provider_type: {provider_type}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 设置 provider
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_enum,
|
||||
umo=session_id,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}"
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话提供商失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_info(self):
|
||||
"""获取指定会话的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
# 获取会话对话信息
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
conversation_id = session_conversations.get(session_id)
|
||||
|
||||
if not conversation_id:
|
||||
return Response().error(f"会话 {session_id} 未找到对话").__dict__
|
||||
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"tts_enabled": None, # 将在下面设置
|
||||
"platform": session_id.split(":")[0]
|
||||
if ":" in session_id
|
||||
else "unknown",
|
||||
"message_type": session_id.split(":")[1]
|
||||
if session_id.count(":") >= 1
|
||||
else "unknown",
|
||||
"session_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
}
|
||||
|
||||
# 获取TTS状态
|
||||
session_info["tts_enabled"] = (
|
||||
SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
)
|
||||
|
||||
# 获取对话信息
|
||||
conversation = self.db_helper.get_conversation_by_user_id(
|
||||
session_id, conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
personas = provider_manager.personas
|
||||
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = default_stt_provider.meta().id
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = default_tts_provider.meta().id
|
||||
|
||||
return Response().ok(session_info).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话信息失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_plugins(self):
|
||||
"""获取指定会话的插件配置信息"""
|
||||
try:
|
||||
session_id = request.args.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 获取所有已激活的插件
|
||||
all_plugins = []
|
||||
plugin_manager = self.core_lifecycle.plugin_manager
|
||||
|
||||
for plugin in plugin_manager.context.get_all_stars():
|
||||
# 只显示已激活的插件,不包括保留插件
|
||||
if plugin.activated and not plugin.reserved:
|
||||
plugin_name = plugin.name or ""
|
||||
plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin_name
|
||||
)
|
||||
|
||||
all_plugins.append(
|
||||
{
|
||||
"name": plugin_name,
|
||||
"author": plugin.author,
|
||||
"desc": plugin.desc,
|
||||
"enabled": plugin_enabled,
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"plugins": all_plugins,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话插件配置失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话插件配置失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_plugin(self):
|
||||
"""更新指定会话的插件启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
plugin_name = data.get("plugin_name")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if not plugin_name:
|
||||
return Response().error("缺少必要参数: plugin_name").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 验证插件是否存在且已激活
|
||||
plugin_manager = self.core_lifecycle.plugin_manager
|
||||
plugin = plugin_manager.context.get_registered_star(plugin_name)
|
||||
|
||||
if not plugin:
|
||||
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
||||
|
||||
if not plugin.activated:
|
||||
return Response().error(f"插件 {plugin_name} 未激活").__dict__
|
||||
|
||||
if plugin.reserved:
|
||||
return (
|
||||
Response()
|
||||
.error(f"插件 {plugin_name} 是系统保留插件,无法管理")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 使用 SessionPluginManager 更新插件状态
|
||||
SessionPluginManager.set_plugin_status_for_session(
|
||||
session_id, plugin_name, enabled
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"插件 {plugin_name} 已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"plugin_name": plugin_name,
|
||||
"enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话插件状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_llm(self):
|
||||
"""更新指定会话的LLM启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新LLM状态
|
||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_tts(self):
|
||||
"""更新指定会话的TTS启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新TTS状态
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"TTS已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"tts_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话TTS状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_name(self):
|
||||
"""更新指定会话的自定义名称"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
custom_name = data.get("custom_name", "")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新会话名称
|
||||
SessionServiceManager.set_session_custom_name(session_id, custom_name)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"会话名称已更新为: {custom_name if custom_name.strip() else '已清除自定义名称'}",
|
||||
"session_id": session_id,
|
||||
"custom_name": custom_name,
|
||||
"display_name": SessionServiceManager.get_session_display_name(
|
||||
session_id
|
||||
),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话名称失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话名称失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_status(self):
|
||||
"""更新指定会话的整体启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
session_enabled = data.get("session_enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if session_enabled is None:
|
||||
return Response().error("缺少必要参数: session_enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新会话整体状态
|
||||
SessionServiceManager.set_session_status(session_id, session_enabled)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"会话整体状态已更新为: {'启用' if session_enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"session_enabled": session_enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
|
||||
@@ -2,6 +2,7 @@ import traceback
|
||||
import psutil
|
||||
import time
|
||||
import threading
|
||||
import aiohttp
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
@@ -25,6 +26,7 @@ class StatRoute(Route):
|
||||
"/stat/version": ("GET", self.get_version),
|
||||
"/stat/start-time": ("GET", self.get_start_time),
|
||||
"/stat/restart-core": ("POST", self.restart_core),
|
||||
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.register_routes()
|
||||
@@ -45,11 +47,7 @@ class StatRoute(Route):
|
||||
"""将总秒数转换为时分秒组件"""
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return {
|
||||
"hours": hours,
|
||||
"minutes": minutes,
|
||||
"seconds": seconds
|
||||
}
|
||||
return {"hours": hours, "minutes": minutes, "seconds": seconds}
|
||||
|
||||
def is_default_cred(self):
|
||||
username = self.config["dashboard"]["username"]
|
||||
@@ -144,3 +142,40 @@ class StatRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def test_ghproxy_connection(self):
|
||||
"""
|
||||
测试 GitHub 代理连接是否可用。
|
||||
"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
proxy_url: str = data.get("proxy_url")
|
||||
|
||||
if not proxy_url:
|
||||
return Response().error("proxy_url is required").__dict__
|
||||
|
||||
proxy_url = proxy_url.rstrip("/")
|
||||
|
||||
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
|
||||
start_time = time.time()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
test_url, timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
end_time = time.time()
|
||||
_ = await response.text()
|
||||
ret = {
|
||||
"latency": round((end_time - start_time) * 1000, 2),
|
||||
}
|
||||
return Response().ok(data=ret).__dict__
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed. Status code: {response.status}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {str(e)}").__dict__
|
||||
|
||||
@@ -26,6 +26,7 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/update": ("POST", self.update_mcp_server),
|
||||
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
|
||||
"/tools/mcp/market": ("GET", self.get_mcp_markets),
|
||||
"/tools/mcp/test": ("POST", self.test_mcp_connection),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
@@ -132,12 +133,19 @@ class ToolsRoute(Route):
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 动态初始化新MCP客户端
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, server_config, timeout=30
|
||||
)
|
||||
except TimeoutError:
|
||||
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
return Response().error("保存配置失败").__dict__
|
||||
@@ -193,31 +201,55 @@ class ToolsRoute(Route):
|
||||
if self.save_mcp_config(config):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
# 如果要激活服务器或者配置已更改
|
||||
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
else:
|
||||
# 客户端不存在,初始化
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError as e:
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, config["mcpServers"][name], timeout=30
|
||||
)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
# 如果要停用服务器
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
self.tool_mgr.mcp_service_queue.put_nowait({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 超时。")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
@@ -239,17 +271,23 @@ class ToolsRoute(Route):
|
||||
if name not in config["mcpServers"]:
|
||||
return Response().error(f"服务器 {name} 不存在").__dict__
|
||||
|
||||
# 删除服务器配置
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 关闭并删除MCP客户端
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
self.tool_mgr.mcp_service_queue.put_nowait({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
return Response().error("保存配置失败").__dict__
|
||||
@@ -281,3 +319,20 @@ class ToolsRoute(Route):
|
||||
except Exception as _:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error("获取市场数据失败").__dict__
|
||||
|
||||
async def test_mcp_connection(self):
|
||||
"""
|
||||
测试 MCP 服务器连接
|
||||
"""
|
||||
try:
|
||||
server_data = await request.json
|
||||
config = server_data.get("mcp_server_config", None)
|
||||
|
||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||
return (
|
||||
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import logging
|
||||
import jwt
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from astrbot.core.config.default import VERSION
|
||||
from quart import Quart, request, jsonify, g
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
from .routes.route import RouteContext, Response
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
|
||||
APP: Quart = None
|
||||
|
||||
@@ -53,6 +57,9 @@ class AstrBotDashboard:
|
||||
self.tools_root = ToolsRoute(self.context, core_lifecycle)
|
||||
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
|
||||
self.file_route = FileRoute(self.context)
|
||||
self.session_management_route = SessionManagementRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
3
changelogs/v3.5.22.md
Normal file
3
changelogs/v3.5.22.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复: 用户环境没有 Docker 时,可能导致死锁(表现为在初始化 AstrBot 的时候卡住)
|
||||
18
changelogs/v3.5.23.md
Normal file
18
changelogs/v3.5.23.md
Normal file
@@ -0,0 +1,18 @@
|
||||
1. 改进: WebUI提供商徽标显示
|
||||
2. 修复:在LLMRequestSubStage中添加对提供商请求处理的调试日志记录
|
||||
3. 修复: 为嵌入模型提供商添加状态检查
|
||||
4. 新增: 支持在WebUI上管理会话
|
||||
5. 新增: 为ProviderMetadata添加provider_type字段并优化提供商可用性测试
|
||||
6. 改进: WebUI聊天页面Markdown代码块
|
||||
7. 修复: 讯飞模型工具使用错误
|
||||
8. 修复: 修复mcp导致的持续占用100% CPU
|
||||
9. 重构: mcp服务器重载机制
|
||||
10. 新增: 为WebChat页面添加文件上传按钮
|
||||
11. 优化: 工具使用页面用户界面
|
||||
12. 新增: 添加测试GitHub加速地址的组件
|
||||
13. 新增: 使用会话锁保证分段回复时的消息发送顺序
|
||||
14. 新增: 实现日志历史记录检索并改进日志流处理
|
||||
15. 杂务: 修改openai的嵌入模型默认维度为1024
|
||||
16. 修复:更新axios版本范围
|
||||
17. chore: remove adapters of WeChat personal account(gewechat)
|
||||
18. 新增: 为AstrBotConfig中的嵌套对象添加展开状态管理
|
||||
@@ -17,7 +17,7 @@
|
||||
"@tiptap/starter-kit": "2.1.7",
|
||||
"@tiptap/vue-3": "2.1.7",
|
||||
"apexcharts": "3.42.0",
|
||||
"axios": "^1.6.2",
|
||||
"axios": ">=1.6.2 <1.10.0 || >1.10.0 <2.0.0",
|
||||
"axios-mock-adapter": "^1.22.0",
|
||||
"chance": "1.1.11",
|
||||
"d3": "^7.9.0",
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
hide-details
|
||||
density="compact"
|
||||
:model-value="getItemEnabled()"
|
||||
:loading="loading"
|
||||
:disabled="loading"
|
||||
v-bind="props"
|
||||
@update:model-value="toggleEnabled"
|
||||
></v-switch>
|
||||
@@ -47,7 +49,6 @@
|
||||
contain
|
||||
width="120"
|
||||
height="120"
|
||||
class="rounded-circle"
|
||||
></v-img>
|
||||
</div>
|
||||
</v-card>
|
||||
@@ -78,6 +79,10 @@ export default {
|
||||
bglogo: {
|
||||
type: String,
|
||||
default: null
|
||||
},
|
||||
loading: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
emits: ['toggle-enabled', 'delete', 'edit'],
|
||||
|
||||
152
dashboard/src/components/shared/ProxySelector.vue
Normal file
152
dashboard/src/components/shared/ProxySelector.vue
Normal file
@@ -0,0 +1,152 @@
|
||||
<template>
|
||||
<h5>GitHub 加速</h5>
|
||||
<v-radio-group class="mt-2" v-model="radioValue" hide-details="true">
|
||||
<v-radio label="不使用 GitHub 加速" value="0"></v-radio>
|
||||
<v-radio value="1">
|
||||
<template v-slot:label>
|
||||
<span>使用 GitHub 加速</span>
|
||||
<v-btn v-if="radioValue === '1'" class="ml-2" @click="testAllProxies" size="x-small"
|
||||
variant="tonal" :loading="loadingTestingConnection">
|
||||
测试代理连通性
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
<div v-if="radioValue === '1'" style="margin-left: 16px;">
|
||||
<v-radio-group v-model="githubProxyRadioControl" class="mt-2" hide-details="true">
|
||||
<v-radio color="success" v-for="(proxy, idx) in githubProxies" :key="proxy" :value="idx">
|
||||
<template v-slot:label>
|
||||
<div class="d-flex align-center">
|
||||
<span class="mr-2">{{ proxy }}</span>
|
||||
<div v-if="proxyStatus[idx]">
|
||||
<v-chip
|
||||
:color="proxyStatus[idx].available ? 'success' : 'error'"
|
||||
size="x-small"
|
||||
class="mr-1">
|
||||
{{ proxyStatus[idx].available ? '可用' : '不可用' }}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-if="proxyStatus[idx].available"
|
||||
color="info"
|
||||
size="x-small">
|
||||
{{ proxyStatus[idx].latency }}ms
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</v-radio>
|
||||
<v-radio color="primary" value="-1" label="自定义">
|
||||
<template v-slot:label v-if="githubProxyRadioControl === '-1'">
|
||||
<v-text-field density="compact" v-model="selectedGitHubProxy" variant="outlined"
|
||||
style="width: 100vw;" placeholder="自定义" hide-details="true">
|
||||
</v-text-field>
|
||||
</template>
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/settings');
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
githubProxies: [
|
||||
"https://edgeone.gh-proxy.com",
|
||||
"https://hk.gh-proxy.com/",
|
||||
"https://gh-proxy.com/",
|
||||
"https://gh.llkk.cc",
|
||||
],
|
||||
githubProxyRadioControl: "0", // the index of the selected proxy
|
||||
selectedGitHubProxy: "",
|
||||
radioValue: "0", // 0: 不使用, 1: 使用
|
||||
loadingTestingConnection: false,
|
||||
testingProxies: {},
|
||||
proxyStatus: {},
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
async testSingleProxy(idx) {
|
||||
this.testingProxies[idx] = true;
|
||||
|
||||
const proxy = this.githubProxies[idx];
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/stat/test-ghproxy-connection', {
|
||||
proxy_url: proxy
|
||||
});
|
||||
console.log(response.data);
|
||||
if (response.status === 200) {
|
||||
this.proxyStatus[idx] = {
|
||||
available: true,
|
||||
latency: Math.round(response.data.data.latency)
|
||||
};
|
||||
} else {
|
||||
this.proxyStatus[idx] = {
|
||||
available: false,
|
||||
latency: 0
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
this.proxyStatus[idx] = {
|
||||
available: false,
|
||||
latency: 0
|
||||
};
|
||||
} finally {
|
||||
this.testingProxies[idx] = false;
|
||||
}
|
||||
},
|
||||
|
||||
async testAllProxies() {
|
||||
this.loadingTestingConnection = true;
|
||||
|
||||
const promises = this.githubProxies.map((proxy, idx) =>
|
||||
this.testSingleProxy(idx)
|
||||
);
|
||||
|
||||
await Promise.all(promises);
|
||||
this.loadingTestingConnection = false;
|
||||
},
|
||||
},
|
||||
mounted() {
|
||||
this.selectedGitHubProxy = localStorage.getItem('selectedGitHubProxy') || "";
|
||||
this.radioValue = localStorage.getItem('githubProxyRadioValue') || "0";
|
||||
this.githubProxyRadioControl = localStorage.getItem('githubProxyRadioControl') || "0";
|
||||
},
|
||||
watch: {
|
||||
selectedGitHubProxy: function (newVal, oldVal) {
|
||||
if (!newVal) {
|
||||
newVal = ""
|
||||
}
|
||||
localStorage.setItem('selectedGitHubProxy', newVal);
|
||||
},
|
||||
radioValue: function (newVal) {
|
||||
localStorage.setItem('githubProxyRadioValue', newVal);
|
||||
if (newVal === "0") {
|
||||
this.selectedGitHubProxy = "";
|
||||
}
|
||||
},
|
||||
githubProxyRadioControl: function (newVal) {
|
||||
localStorage.setItem('githubProxyRadioControl', newVal);
|
||||
if (newVal !== "-1") {
|
||||
this.selectedGitHubProxy = this.githubProxies[newVal] || "";
|
||||
} else {
|
||||
this.selectedGitHubProxy = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style>
|
||||
.v-label {
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
</style>
|
||||
@@ -38,6 +38,7 @@ export class I18nLoader {
|
||||
{ name: 'features/chat', path: 'features/chat.json' },
|
||||
{ name: 'features/extension', path: 'features/extension.json' },
|
||||
{ name: 'features/conversation', path: 'features/conversation.json' },
|
||||
{ name: 'features/session-management', path: 'features/session-management.json' },
|
||||
{ name: 'features/tooluse', path: 'features/tool-use.json' },
|
||||
{ name: 'features/provider', path: 'features/provider.json' },
|
||||
{ name: 'features/platform', path: 'features/platform.json' },
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"extensionMarketplace": "Extension Market",
|
||||
"chat": "Chat",
|
||||
"conversation": "Conversations",
|
||||
"sessionManagement": "Session Management",
|
||||
"console": "Console",
|
||||
"alkaid": "Alkaid Lab",
|
||||
"about": "About",
|
||||
|
||||
@@ -32,7 +32,8 @@
|
||||
"cancel": "Cancel",
|
||||
"actions": "Actions",
|
||||
"back": "Back",
|
||||
"selectFile": "Select File"
|
||||
"selectFile": "Select File",
|
||||
"refresh": "Refresh"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "Enabled",
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
{
|
||||
"title": "Session Management",
|
||||
"subtitle": "Manage active sessions and configurations",
|
||||
"buttons": {
|
||||
"refresh": "Refresh",
|
||||
"edit": "Edit",
|
||||
"apply": "Apply Batch Settings",
|
||||
"editName": "Edit Session Name",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "Active Sessions",
|
||||
"sessionCount": "sessions",
|
||||
"noActiveSessions": "No active sessions",
|
||||
"noActiveSessionsDesc": "Sessions will appear here when users interact with the bot"
|
||||
},
|
||||
"search": {
|
||||
"placeholder": "Search sessions...",
|
||||
"platformFilter": "Platform Filter"
|
||||
},
|
||||
"table": {
|
||||
"headers": {
|
||||
"sessionStatus": "Session Status",
|
||||
"sessionInfo": "Session Info",
|
||||
"persona": "Persona",
|
||||
"chatProvider": "Chat Provider",
|
||||
"sttProvider": "STT Provider",
|
||||
"ttsProvider": "TTS Provider",
|
||||
"llmStatus": "LLM Status",
|
||||
"ttsStatus": "TTS Status",
|
||||
"pluginManagement": "Plugin Management"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled"
|
||||
},
|
||||
"persona": {
|
||||
"none": "No Persona"
|
||||
},
|
||||
"batchOperations": {
|
||||
"title": "Batch Operations",
|
||||
"setPersona": "Batch Set Persona",
|
||||
"setChatProvider": "Batch Set Chat Provider",
|
||||
"setSttProvider": "Batch Set STT Provider",
|
||||
"setTtsProvider": "Batch Set TTS Provider",
|
||||
"setLlmStatus": "Batch Set LLM Status",
|
||||
"setTtsStatus": "Batch Set TTS Status",
|
||||
"noSttProvider": "No STT Provider Available",
|
||||
"noTtsProvider": "No TTS Provider Available"
|
||||
},
|
||||
"pluginManagement": {
|
||||
"title": "Plugin Management",
|
||||
"noPlugins": "No available plugins",
|
||||
"noPluginsDesc": "Currently no active plugins",
|
||||
"loading": "Loading plugin list...",
|
||||
"author": "Author"
|
||||
},
|
||||
"nameEditor": {
|
||||
"title": "Edit Session Name",
|
||||
"customName": "Custom Name",
|
||||
"placeholder": "Enter custom session name (leave empty to use original name)",
|
||||
"originalName": "Original Name",
|
||||
"fullSessionId": "Full Session ID",
|
||||
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "Session list refreshed",
|
||||
"personaUpdateSuccess": "Persona updated successfully",
|
||||
"personaUpdateError": "Failed to update persona",
|
||||
"providerUpdateSuccess": "Provider updated successfully",
|
||||
"providerUpdateError": "Failed to update provider",
|
||||
"sessionStatusSuccess": "Session {status}",
|
||||
"llmStatusSuccess": "LLM {status}",
|
||||
"ttsStatusSuccess": "TTS {status}",
|
||||
"statusUpdateError": "Failed to update status",
|
||||
"loadSessionsError": "Failed to load session list",
|
||||
"batchUpdateSuccess": "Successfully batch updated {count} settings",
|
||||
"batchUpdatePartial": "Batch update completed, {success} successful, {error} failed",
|
||||
"loadPluginsError": "Failed to load plugin list",
|
||||
"pluginStatusSuccess": "Plugin {name} {status}",
|
||||
"pluginStatusError": "Failed to update plugin status",
|
||||
"nameUpdateSuccess": "Session name updated successfully",
|
||||
"nameUpdateError": "Failed to update session name"
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,9 @@
|
||||
"buttons": {
|
||||
"refresh": "Refresh",
|
||||
"add": "Add Server",
|
||||
"useTemplate": "Use Template"
|
||||
"useTemplateStdio": "Stdio Template",
|
||||
"useTemplateStreamableHttp": "Streamable HTTP Template",
|
||||
"useTemplateSse": "SSE Template"
|
||||
},
|
||||
"empty": "No MCP servers available, click Add Server to add one",
|
||||
"status": {
|
||||
@@ -28,8 +30,7 @@
|
||||
"functionTools": {
|
||||
"title": "Function Tools",
|
||||
"buttons": {
|
||||
"expand": "Expand",
|
||||
"collapse": "Collapse"
|
||||
"view": "View Tools"
|
||||
},
|
||||
"search": "Search function tools",
|
||||
"empty": "No function tools available",
|
||||
@@ -68,10 +69,6 @@
|
||||
"enable": "Enable Server",
|
||||
"config": "Server Configuration"
|
||||
},
|
||||
"configNotes": {
|
||||
"note1": "1. Some MCP servers may require filling in `API_KEY` or `TOKEN` information in env according to their requirements, please check if filled.",
|
||||
"note2": "2. When url parameter is specified in configuration: if `transport` parameter is also specified as `streamable_http`, Streamable HTTP is used, otherwise SSE connection is used."
|
||||
},
|
||||
"errors": {
|
||||
"configEmpty": "Configuration cannot be empty",
|
||||
"jsonFormat": "JSON format error: {error}",
|
||||
@@ -79,7 +76,8 @@
|
||||
},
|
||||
"buttons": {
|
||||
"cancel": "Cancel",
|
||||
"save": "Save"
|
||||
"save": "Save",
|
||||
"testConnection": "Test Connection"
|
||||
}
|
||||
},
|
||||
"serverDetail": {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"extensionMarketplace": "插件市场",
|
||||
"chat": "聊天",
|
||||
"conversation": "对话数据库",
|
||||
"sessionManagement": "会话管理",
|
||||
"console": "控制台",
|
||||
"alkaid": "Alkaid",
|
||||
"about": "关于",
|
||||
|
||||
@@ -32,7 +32,8 @@
|
||||
"cancel": "取消",
|
||||
"actions": "操作",
|
||||
"back": "返回",
|
||||
"selectFile": "选择文件"
|
||||
"selectFile": "选择文件",
|
||||
"refresh": "刷新"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "启用",
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
{
|
||||
"title": "会话管理",
|
||||
"subtitle": "管理活跃会话和配置",
|
||||
"buttons": {
|
||||
"refresh": "刷新",
|
||||
"edit": "编辑",
|
||||
"apply": "应用批量设置",
|
||||
"editName": "编辑会话名称",
|
||||
"save": "保存",
|
||||
"cancel": "取消"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "活跃会话",
|
||||
"sessionCount": "个会话",
|
||||
"noActiveSessions": "暂无活跃会话",
|
||||
"noActiveSessionsDesc": "当有用户与机器人交互时,会话将会显示在这里"
|
||||
},
|
||||
"search": {
|
||||
"placeholder": "搜索会话...",
|
||||
"platformFilter": "平台筛选"
|
||||
},
|
||||
"table": {
|
||||
"headers": {
|
||||
"sessionStatus": "会话状态",
|
||||
"sessionInfo": "会话信息",
|
||||
"persona": "人格",
|
||||
"chatProvider": "Chat Provider",
|
||||
"sttProvider": "STT Provider",
|
||||
"ttsProvider": "TTS Provider",
|
||||
"llmStatus": "LLM启停",
|
||||
"ttsStatus": "TTS启停",
|
||||
"pluginManagement": "插件管理"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
"enabled": "已启用",
|
||||
"disabled": "已禁用"
|
||||
},
|
||||
"persona": {
|
||||
"none": "无人格"
|
||||
},
|
||||
"batchOperations": {
|
||||
"title": "批量操作",
|
||||
"setPersona": "批量设置人格",
|
||||
"setChatProvider": "批量设置 Chat Provider",
|
||||
"setSttProvider": "批量设置 STT Provider",
|
||||
"setTtsProvider": "批量设置 TTS Provider",
|
||||
"setLlmStatus": "批量设置 LLM 状态",
|
||||
"setTtsStatus": "批量设置 TTS 状态",
|
||||
"noSttProvider": "暂无可用 STT Provider",
|
||||
"noTtsProvider": "暂无可用 TTS Provider"
|
||||
},
|
||||
"pluginManagement": {
|
||||
"title": "插件管理",
|
||||
"noPlugins": "暂无可用插件",
|
||||
"noPluginsDesc": "目前没有激活的插件",
|
||||
"loading": "加载插件列表中...",
|
||||
"author": "作者"
|
||||
},
|
||||
"nameEditor": {
|
||||
"title": "编辑会话名称",
|
||||
"customName": "自定义名称",
|
||||
"placeholder": "输入自定义会话名称(留空则使用原始名称)",
|
||||
"originalName": "原始名称",
|
||||
"fullSessionId": "完整会话ID",
|
||||
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时,会显示一个小感叹号标识(!),鼠标悬停时会显示实际的UMO。"
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "会话列表已刷新",
|
||||
"personaUpdateSuccess": "人格更新成功",
|
||||
"personaUpdateError": "人格更新失败",
|
||||
"providerUpdateSuccess": "Provider 更新成功",
|
||||
"providerUpdateError": "Provider 更新失败",
|
||||
"sessionStatusSuccess": "会话 {status}",
|
||||
"llmStatusSuccess": "LLM {status}",
|
||||
"ttsStatusSuccess": "TTS {status}",
|
||||
"statusUpdateError": "状态更新失败",
|
||||
"loadSessionsError": "加载会话列表失败",
|
||||
"batchUpdateSuccess": "成功批量更新 {count} 项设置",
|
||||
"batchUpdatePartial": "批量更新完成,{success} 项成功,{error} 项失败",
|
||||
"loadPluginsError": "加载插件列表失败",
|
||||
"pluginStatusSuccess": "插件 {name} {status}",
|
||||
"pluginStatusError": "插件状态更新失败",
|
||||
"nameUpdateSuccess": "会话名称更新成功",
|
||||
"nameUpdateError": "会话名称更新失败"
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,9 @@
|
||||
"buttons": {
|
||||
"refresh": "刷新",
|
||||
"add": "新增服务器",
|
||||
"useTemplate": "使用模板"
|
||||
"useTemplateStdio": "Stdio 模板",
|
||||
"useTemplateStreamableHttp": "Streamable HTTP 模板",
|
||||
"useTemplateSse": "SSE 模板"
|
||||
},
|
||||
"empty": "暂无 MCP 服务器,点击 新增服务器 添加",
|
||||
"status": {
|
||||
@@ -28,8 +30,7 @@
|
||||
"functionTools": {
|
||||
"title": "函数工具",
|
||||
"buttons": {
|
||||
"expand": "展开",
|
||||
"collapse": "收起"
|
||||
"view": "查看工具"
|
||||
},
|
||||
"search": "搜索函数工具",
|
||||
"empty": "没有可用的函数工具",
|
||||
@@ -68,10 +69,6 @@
|
||||
"enable": "启用服务器",
|
||||
"config": "服务器配置"
|
||||
},
|
||||
"configNotes": {
|
||||
"note1": "1. 某些 MCP 服务器可能需要按照其要求在 env 中填充 `API_KEY` 或 `TOKEN` 等信息,请注意检查是否填写。",
|
||||
"note2": "2. 当配置中指定 url 参数时:如果还同时指定 `transport` 参数的值为 `streamable_http`,则使用 Steamable HTTP,否则使用 SSE 连接。"
|
||||
},
|
||||
"errors": {
|
||||
"configEmpty": "配置不能为空",
|
||||
"jsonFormat": "JSON 格式错误: {error}",
|
||||
@@ -79,7 +76,8 @@
|
||||
},
|
||||
"buttons": {
|
||||
"cancel": "取消",
|
||||
"save": "保存"
|
||||
"save": "保存",
|
||||
"testConnection": "测试连接"
|
||||
}
|
||||
},
|
||||
"serverDetail": {
|
||||
|
||||
@@ -11,6 +11,7 @@ import zhCNHeader from './locales/zh-CN/core/header.json';
|
||||
import zhCNChat from './locales/zh-CN/features/chat.json';
|
||||
import zhCNExtension from './locales/zh-CN/features/extension.json';
|
||||
import zhCNConversation from './locales/zh-CN/features/conversation.json';
|
||||
import zhCNSessionManagement from './locales/zh-CN/features/session-management.json';
|
||||
import zhCNToolUse from './locales/zh-CN/features/tool-use.json';
|
||||
import zhCNProvider from './locales/zh-CN/features/provider.json';
|
||||
import zhCNPlatform from './locales/zh-CN/features/platform.json';
|
||||
@@ -39,6 +40,7 @@ import enUSHeader from './locales/en-US/core/header.json';
|
||||
import enUSChat from './locales/en-US/features/chat.json';
|
||||
import enUSExtension from './locales/en-US/features/extension.json';
|
||||
import enUSConversation from './locales/en-US/features/conversation.json';
|
||||
import enUSSessionManagement from './locales/en-US/features/session-management.json';
|
||||
import enUSToolUse from './locales/en-US/features/tool-use.json';
|
||||
import enUSProvider from './locales/en-US/features/provider.json';
|
||||
import enUSPlatform from './locales/en-US/features/platform.json';
|
||||
@@ -71,6 +73,7 @@ export const translations = {
|
||||
chat: zhCNChat,
|
||||
extension: zhCNExtension,
|
||||
conversation: zhCNConversation,
|
||||
'session-management': zhCNSessionManagement,
|
||||
tooluse: zhCNToolUse,
|
||||
provider: zhCNProvider,
|
||||
platform: zhCNPlatform,
|
||||
@@ -105,6 +108,7 @@ export const translations = {
|
||||
chat: enUSChat,
|
||||
extension: enUSExtension,
|
||||
conversation: enUSConversation,
|
||||
'session-management': enUSSessionManagement,
|
||||
tooluse: enUSToolUse,
|
||||
provider: enUSProvider,
|
||||
platform: enUSPlatform,
|
||||
|
||||
@@ -58,8 +58,14 @@ const sidebarItem: menu[] = [
|
||||
icon: 'mdi-database',
|
||||
to: '/conversation'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.sessionManagement',
|
||||
icon: 'mdi-account-group',
|
||||
to: '/session-management'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.console',
|
||||
|
||||
icon: 'mdi-console',
|
||||
to: '/console'
|
||||
},
|
||||
|
||||
@@ -51,6 +51,11 @@ const MainRoutes = {
|
||||
path: '/conversation',
|
||||
component: () => import('@/views/ConversationPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'SessionManagement',
|
||||
path: '/session-management',
|
||||
component: () => import('@/views/SessionManagementPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'Console',
|
||||
path: '/console',
|
||||
|
||||
@@ -15,7 +15,22 @@ export const useCommonStore = defineStore({
|
||||
pluginMarketData: [],
|
||||
}),
|
||||
actions: {
|
||||
createEventSource() {
|
||||
async createEventSource() {
|
||||
|
||||
const fetchLogHistory = async () => {
|
||||
try {
|
||||
const res = await axios.get('/api/log-history');
|
||||
if (res.data.data.logs) {
|
||||
this.log_cache.push(...res.data.data.logs);
|
||||
} else {
|
||||
this.log_cache = [];
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch log history:', err);
|
||||
}
|
||||
};
|
||||
await fetchLogHistory();
|
||||
|
||||
if (this.eventSource) {
|
||||
return
|
||||
}
|
||||
@@ -40,7 +55,24 @@ export const useCommonStore = defineStore({
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
let incompleteLine = ""; // 用于存储不完整的行
|
||||
|
||||
const handleIncompleteLine = (line) => {
|
||||
incompleteLine += line;
|
||||
// if can parse as JSON, return it
|
||||
try {
|
||||
const data_json = JSON.parse(incompleteLine);
|
||||
incompleteLine = ""; // 清空不完整行
|
||||
return data_json;
|
||||
} catch (e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
const processStream = ({ done, value }) => {
|
||||
// get bytes length
|
||||
const bytesLength = value ? value.byteLength : 0;
|
||||
console.log(`Received ${bytesLength} bytes from live log`);
|
||||
if (done) {
|
||||
console.log('SSE stream closed');
|
||||
setTimeout(() => {
|
||||
@@ -53,6 +85,9 @@ export const useCommonStore = defineStore({
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.split('\n\n');
|
||||
lines.forEach(line => {
|
||||
if (!line.trim()) {
|
||||
return;
|
||||
}
|
||||
if (line.startsWith('data:')) {
|
||||
const data = line.substring(5).trim();
|
||||
// {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"}
|
||||
@@ -60,21 +95,29 @@ export const useCommonStore = defineStore({
|
||||
try {
|
||||
data_json = JSON.parse(data);
|
||||
} catch (e) {
|
||||
console.error('Invalid JSON:', data);
|
||||
data_json = {
|
||||
type: 'log',
|
||||
data: data,
|
||||
level: 'INFO',
|
||||
time: new Date().toISOString(),
|
||||
console.warn('Invalid JSON:', data);
|
||||
// 尝试处理不完整的行
|
||||
const parsedData = handleIncompleteLine(data);
|
||||
if (parsedData) {
|
||||
data_json = parsedData;
|
||||
} else {
|
||||
return; // 如果无法解析,跳过当前行
|
||||
}
|
||||
}
|
||||
if (data_json.type === 'log') {
|
||||
// let log = data_json.data
|
||||
this.log_cache.push(data_json);
|
||||
if (this.log_cache.length > this.log_cache_max_len) {
|
||||
this.log_cache.shift();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const parsedData = handleIncompleteLine(line);
|
||||
if (parsedData && parsedData.type === 'log') {
|
||||
this.log_cache.push(parsedData);
|
||||
if (this.log_cache.length > this.log_cache_max_len) {
|
||||
this.log_cache.shift();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return reader.read().then(processStream);
|
||||
@@ -116,7 +159,11 @@ export const useCommonStore = defineStore({
|
||||
if (!force && this.pluginMarketData.length > 0) {
|
||||
return Promise.resolve(this.pluginMarketData);
|
||||
}
|
||||
return axios.get('/api/plugin/market_list')
|
||||
|
||||
// 如果是强制刷新,添加 force_refresh 参数
|
||||
const url = force ? '/api/plugin/market_list?force_refresh=true' : '/api/plugin/market_list';
|
||||
|
||||
return axios.get(url)
|
||||
.then((res) => {
|
||||
let data = []
|
||||
for (let key in res.data.data) {
|
||||
|
||||
@@ -36,11 +36,13 @@ const PurpleThemeDark: ThemeTypes = {
|
||||
gray100: '#cccccccc',
|
||||
primary200: '#90caf9',
|
||||
secondary200: '#b39ddb',
|
||||
background: '#111111',
|
||||
background: '#1d1d1d',
|
||||
overlay: '#111111aa',
|
||||
codeBg: '#282833',
|
||||
preBg: 'rgb(23, 23, 23)',
|
||||
code: '#ffffffdd',
|
||||
chatMessageBubble: '#2d2e30',
|
||||
mcpCardBg: '#2a2a2a',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ const PurpleTheme: ThemeTypes = {
|
||||
borderLight: '#d0d0d0',
|
||||
border: '#d0d0d0',
|
||||
inputBorder: '#787878',
|
||||
containerBg: '#f7f1f6',
|
||||
containerBg: '#f9fafcf4',
|
||||
surface: '#fff',
|
||||
'on-surface-variant': '#fff',
|
||||
facebook: '#4267b2',
|
||||
@@ -36,11 +36,13 @@ const PurpleTheme: ThemeTypes = {
|
||||
gray100: '#fafafacc',
|
||||
primary200: '#90caf9',
|
||||
secondary200: '#b39ddb',
|
||||
background: '#f9fafcf4',
|
||||
background: '#ffffff',
|
||||
overlay: '#ffffffaa',
|
||||
codeBg: '#f5f0ff',
|
||||
code: '#673ab7',
|
||||
codeBg: '#ececec',
|
||||
preBg: 'rgb(249, 249, 249)',
|
||||
code: 'rgb(13, 13, 13)',
|
||||
chatMessageBubble: '#e7ebf4',
|
||||
mcpCardBg: '#f7f2f9',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -34,7 +34,9 @@ export type ThemeTypes = {
|
||||
primary200?: string;
|
||||
secondary200?: string;
|
||||
codeBg?: string;
|
||||
preBg?: string;
|
||||
code?: string;
|
||||
chatMessageBubble?: string;
|
||||
mcpCardBg?: string;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -126,17 +126,17 @@
|
||||
<span>Hello, I'm</span>
|
||||
<span class="bot-name">AstrBot ⭐</span>
|
||||
</div>
|
||||
<div class="welcome-hint">
|
||||
<div class="welcome-hint markdown-content">
|
||||
<span>{{ t('core.common.type') }}</span>
|
||||
<code>help</code>
|
||||
<span>{{ tm('shortcuts.help') }} 😊</span>
|
||||
</div>
|
||||
<div class="welcome-hint">
|
||||
<div class="welcome-hint markdown-content">
|
||||
<span>{{ t('core.common.longPress') }}</span>
|
||||
<code>Ctrl + B</code>
|
||||
<span>{{ tm('shortcuts.voiceRecord') }} 🎤</span>
|
||||
</div>
|
||||
<div class="welcome-hint">
|
||||
<div class="welcome-hint markdown-content">
|
||||
<span>{{ t('core.common.press') }}</span>
|
||||
<code>Ctrl + V</code>
|
||||
<span>{{ tm('shortcuts.pasteImage') }} 🏞️</span>
|
||||
@@ -151,7 +151,7 @@
|
||||
<div class="message-bubble user-bubble"
|
||||
:class="{ 'has-audio': msg.audio_url }"
|
||||
:style="{ backgroundColor: isDark ? '#2d2e30' : '#e7ebf4' }">
|
||||
<span>{{ msg.message }}</span>
|
||||
<pre style="font-family: inherit; white-space: pre-wrap; word-wrap: break-word;">{{ msg.message }}</pre>
|
||||
|
||||
<!-- 图片附件 -->
|
||||
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
|
||||
@@ -218,14 +218,17 @@
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px; padding: 4px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown"
|
||||
@click:clear="clearMessage" placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 12px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div
|
||||
style="display: flex; justify-content: space-between; align-items: center; padding: 0px 8px;">
|
||||
<div style="display: flex; justify-content: flex-start; margin-top: 8px;">
|
||||
<div style="display: flex; justify-content: flex-start; margin-top: 4px;">
|
||||
<!-- 选择提供商和模型 -->
|
||||
<ProviderModelSelector ref="providerModelSelector" />
|
||||
</div>
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px;">
|
||||
<input type="file" ref="imageInput" @change="handleFileSelect" accept="image/*" style="display: none" multiple />
|
||||
<v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple"
|
||||
class="add-btn" size="small" />
|
||||
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
|
||||
class="send-btn" size="small" />
|
||||
@@ -668,12 +671,7 @@ export default {
|
||||
};
|
||||
},
|
||||
|
||||
async handlePaste(event) {
|
||||
console.log('Pasting image...');
|
||||
const items = event.clipboardData.items;
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
if (items[i].type.indexOf('image') !== -1) {
|
||||
const file = items[i].getAsFile();
|
||||
async processAndUploadImage(file) {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
@@ -691,11 +689,26 @@ export default {
|
||||
} catch (err) {
|
||||
console.error('Error uploading image:', err);
|
||||
}
|
||||
},
|
||||
|
||||
async handlePaste(event) {
|
||||
console.log('Pasting image...');
|
||||
const items = event.clipboardData.items;
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
if (items[i].type.indexOf('image') !== -1) {
|
||||
const file = items[i].getAsFile();
|
||||
this.processAndUploadImage(file);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
removeImage(index) {
|
||||
// Revoke the blob URL to prevent memory leaks
|
||||
const urlToRevoke = this.stagedImagesUrl[index];
|
||||
if (urlToRevoke && urlToRevoke.startsWith('blob:')) {
|
||||
URL.revokeObjectURL(urlToRevoke);
|
||||
}
|
||||
|
||||
this.stagedImagesName.splice(index, 1);
|
||||
this.stagedImagesUrl.splice(index, 1);
|
||||
},
|
||||
@@ -703,6 +716,21 @@ export default {
|
||||
clearMessage() {
|
||||
this.prompt = '';
|
||||
},
|
||||
|
||||
triggerImageInput() {
|
||||
this.$refs.imageInput.click();
|
||||
},
|
||||
|
||||
handleFileSelect(event) {
|
||||
const files = event.target.files;
|
||||
if (files) {
|
||||
for (const file of files) {
|
||||
this.processAndUploadImage(file);
|
||||
}
|
||||
}
|
||||
// Reset the input value to allow selecting the same file again
|
||||
event.target.value = '';
|
||||
},
|
||||
getConversations() {
|
||||
axios.get('/api/chat/conversations').then(response => {
|
||||
this.conversations = response.data.data;
|
||||
@@ -846,33 +874,42 @@ export default {
|
||||
// URL is already updated in newConversation method
|
||||
}
|
||||
|
||||
// 保存当前要发送的数据到临时变量
|
||||
const promptToSend = this.prompt.trim();
|
||||
const imageNamesToSend = [...this.stagedImagesName];
|
||||
const audioNameToSend = this.stagedAudioUrl;
|
||||
|
||||
// 立即清空输入和附件预览
|
||||
this.prompt = '';
|
||||
this.stagedImagesName = [];
|
||||
this.stagedImagesUrl = [];
|
||||
this.stagedAudioUrl = "";
|
||||
|
||||
// Create a message object with actual URLs for display
|
||||
const userMessage = {
|
||||
type: 'user',
|
||||
message: this.prompt.trim(), // 使用 trim() 去除前后空格
|
||||
message: promptToSend,
|
||||
image_url: [],
|
||||
audio_url: null
|
||||
};
|
||||
|
||||
// Convert image filenames to blob URLs for display
|
||||
if (this.stagedImagesName.length > 0) {
|
||||
for (let i = 0; i < this.stagedImagesName.length; i++) {
|
||||
// If it's just a filename, get the blob URL
|
||||
if (!this.stagedImagesName[i].startsWith('blob:')) {
|
||||
const imgUrl = await this.getMediaFile(this.stagedImagesName[i]);
|
||||
userMessage.image_url.push(imgUrl);
|
||||
} else {
|
||||
userMessage.image_url.push(this.stagedImagesName[i]);
|
||||
}
|
||||
if (imageNamesToSend.length > 0) {
|
||||
const imagePromises = imageNamesToSend.map(name => {
|
||||
if (!name.startsWith('blob:')) {
|
||||
return this.getMediaFile(name);
|
||||
}
|
||||
return Promise.resolve(name);
|
||||
});
|
||||
userMessage.image_url = await Promise.all(imagePromises);
|
||||
}
|
||||
|
||||
// Convert audio filename to blob URL for display
|
||||
if (this.stagedAudioUrl) {
|
||||
if (!this.stagedAudioUrl.startsWith('blob:')) {
|
||||
userMessage.audio_url = await this.getMediaFile(this.stagedAudioUrl);
|
||||
if (audioNameToSend) {
|
||||
if (!audioNameToSend.startsWith('blob:')) {
|
||||
userMessage.audio_url = await this.getMediaFile(audioNameToSend);
|
||||
} else {
|
||||
userMessage.audio_url = this.stagedAudioUrl;
|
||||
userMessage.audio_url = audioNameToSend;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -894,17 +931,15 @@ export default {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message: this.prompt.trim(), // 确保发送的消息已去除前后空格
|
||||
message: promptToSend,
|
||||
conversation_id: this.currCid,
|
||||
image_url: this.stagedImagesName,
|
||||
audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [],
|
||||
image_url: imageNamesToSend,
|
||||
audio_url: audioNameToSend ? [audioNameToSend] : [],
|
||||
selected_provider: selectedProviderId,
|
||||
selected_model: selectedModelName
|
||||
})
|
||||
});
|
||||
|
||||
this.prompt = ''; // 清空输入框;
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
@@ -1003,11 +1038,7 @@ export default {
|
||||
}
|
||||
}
|
||||
|
||||
// Clear input after successful send
|
||||
this.prompt = '';
|
||||
this.stagedImagesName = [];
|
||||
this.stagedImagesUrl = [];
|
||||
this.stagedAudioUrl = "";
|
||||
// Input and attachments are already cleared
|
||||
this.loadingChat = false;
|
||||
|
||||
// get the latest conversations
|
||||
@@ -1479,11 +1510,11 @@ export default {
|
||||
}
|
||||
|
||||
.welcome-hint code {
|
||||
background-color: var(--v-theme-codeBg);
|
||||
background-color: rgb(var(--v-theme-codeBg));
|
||||
padding: 2px 6px;
|
||||
margin: 0 4px;
|
||||
border-radius: 4px;
|
||||
color: var(--v-theme-code);
|
||||
color: rgb(var(--v-theme-code));
|
||||
font-family: 'Fira Code', monospace;
|
||||
font-size: 13px;
|
||||
}
|
||||
@@ -1571,6 +1602,8 @@ export default {
|
||||
.bot-bubble {
|
||||
border: 1px solid var(--v-theme-border);
|
||||
color: var(--v-theme-primaryText);
|
||||
font-size: 16px;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.user-avatar,
|
||||
@@ -1749,8 +1782,8 @@ export default {
|
||||
}
|
||||
|
||||
.markdown-content p {
|
||||
margin-top: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: .5rem;
|
||||
margin-bottom: .5rem;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
@@ -1763,7 +1796,7 @@ export default {
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
background-color: var(--v-theme-codeBg);
|
||||
background-color: rgb(var(--v-theme-codeBg));
|
||||
padding: 2px 4px;
|
||||
border-radius: 4px;
|
||||
font-family: 'Fira Code', monospace;
|
||||
@@ -1787,7 +1820,9 @@ export default {
|
||||
/* 自定义代码高亮样式 */
|
||||
.markdown-content pre {
|
||||
border: 1px solid var(--v-theme-border);
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
|
||||
background-color: rgb(var(--v-theme-preBg));
|
||||
border-radius: 16px;
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
/* 确保highlight.js的样式正确应用 */
|
||||
@@ -1990,6 +2025,7 @@ export default {
|
||||
}
|
||||
|
||||
.embedded-audio {
|
||||
width: 300px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
|
||||
@@ -364,7 +364,6 @@ export default {
|
||||
'telegram': 'blue-lighten-1',
|
||||
'qq_official': 'purple-lighten-1',
|
||||
'qq_official_webhook': 'purple-lighten-2',
|
||||
'gewechat': 'green-lighten-1',
|
||||
'aiocqhttp': 'deep-purple-lighten-1',
|
||||
'lark': 'cyan-darken-1',
|
||||
'wecom': 'green-darken-1',
|
||||
|
||||
@@ -3,6 +3,7 @@ import ExtensionCard from '@/components/shared/ExtensionCard.vue';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
|
||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||
import axios from 'axios';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
@@ -70,6 +71,7 @@ const uploadTab = ref('file');
|
||||
const showPluginFullName = ref(false);
|
||||
const marketSearch = ref("");
|
||||
const filterKeys = ['name', 'desc', 'author'];
|
||||
const refreshingMarket = ref(false);
|
||||
|
||||
const plugin_handler_info_headers = computed(() => [
|
||||
{ title: tm('table.headers.eventType'), key: 'event_type_h' },
|
||||
@@ -559,6 +561,25 @@ const newExtension = async () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 刷新插件市场数据
|
||||
const refreshPluginMarket = async () => {
|
||||
refreshingMarket.value = true;
|
||||
try {
|
||||
// 强制刷新插件市场数据
|
||||
const data = await commonStore.getPluginCollections(true);
|
||||
pluginMarketData.value = data;
|
||||
trimExtensionName();
|
||||
checkAlreadyInstalled();
|
||||
checkUpdate();
|
||||
|
||||
toast(tm('messages.refreshSuccess'), "success");
|
||||
} catch (err) {
|
||||
toast(tm('messages.refreshFailed') + " " + err, "error");
|
||||
} finally {
|
||||
refreshingMarket.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
// 生命周期
|
||||
onMounted(async () => {
|
||||
await getExtensions();
|
||||
@@ -622,27 +643,12 @@ onMounted(async () => {
|
||||
<!-- 搜索栏 - 在移动端时独占一行 -->
|
||||
<v-row class="mb-2">
|
||||
<v-col cols="12" sm="6" md="4" lg="3">
|
||||
<v-text-field
|
||||
v-if="activeTab == 'market'"
|
||||
v-model="marketSearch"
|
||||
density="compact"
|
||||
:label="tm('search.marketPlaceholder')"
|
||||
prepend-inner-icon="mdi-magnify"
|
||||
variant="solo-filled"
|
||||
flat
|
||||
hide-details
|
||||
single-line>
|
||||
<v-text-field v-if="activeTab == 'market'" v-model="marketSearch" density="compact"
|
||||
:label="tm('search.marketPlaceholder')" prepend-inner-icon="mdi-magnify" variant="solo-filled" flat
|
||||
hide-details single-line>
|
||||
</v-text-field>
|
||||
<v-text-field
|
||||
v-else
|
||||
v-model="pluginSearch"
|
||||
density="compact"
|
||||
:label="tm('search.placeholder')"
|
||||
prepend-inner-icon="mdi-magnify"
|
||||
variant="solo-filled"
|
||||
flat
|
||||
hide-details
|
||||
single-line>
|
||||
<v-text-field v-else v-model="pluginSearch" density="compact" :label="tm('search.placeholder')"
|
||||
prepend-inner-icon="mdi-magnify" variant="solo-filled" flat hide-details single-line>
|
||||
</v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -678,14 +684,12 @@ onMounted(async () => {
|
||||
<v-icon>mdi-plus</v-icon>
|
||||
{{ tm('buttons.install') }}
|
||||
</v-btn>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="auto" md="6" class="ml-auto">
|
||||
<v-col cols="12" sm="auto" class="ml-auto">
|
||||
<v-dialog max-width="500px" v-if="extension_data.message">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" icon size="small" color="error" class="ml-2" variant="tonal">
|
||||
<v-icon>mdi-alert-circle</v-icon>
|
||||
<v-badge dot color="error" floating></v-badge>
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:default="{ isActive }">
|
||||
@@ -706,6 +710,7 @@ onMounted(async () => {
|
||||
</template>
|
||||
</v-dialog>
|
||||
</v-col>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<v-fade-transition hide-on-leave>
|
||||
@@ -726,7 +731,8 @@ onMounted(async () => {
|
||||
<div>
|
||||
<div class="text-subtitle-1 font-weight-medium">{{ item.name }}</div>
|
||||
<div v-if="item.reserved" class="d-flex align-center mt-1">
|
||||
<v-chip color="primary" size="x-small" class="font-weight-medium">{{ tm('status.system') }}</v-chip>
|
||||
<v-chip color="primary" size="x-small" class="font-weight-medium">{{ tm('status.system')
|
||||
}}</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -847,8 +853,8 @@ onMounted(async () => {
|
||||
|
||||
<!-- <small style="color: var(--v-theme-secondaryText);">每个插件都是作者无偿提供的的劳动成果。如果您喜欢某个插件,请 Star!</small> -->
|
||||
|
||||
<v-btn icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px; z-index: 10000" @click="dialog = true"
|
||||
color="darkprimary">
|
||||
<v-btn icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px; z-index: 10000"
|
||||
@click="dialog = true" color="darkprimary">
|
||||
</v-btn>
|
||||
|
||||
<div v-if="pinnedPlugins.length > 0" class="mt-4">
|
||||
@@ -865,9 +871,21 @@ onMounted(async () => {
|
||||
<div class="mt-4">
|
||||
<div class="d-flex align-center mb-2" style="justify-content: space-between;">
|
||||
<h2>{{ tm('market.allPlugins') }}</h2>
|
||||
<div class="d-flex align-center">
|
||||
<v-btn
|
||||
variant="tonal"
|
||||
size="small"
|
||||
@click="refreshPluginMarket"
|
||||
:loading="refreshingMarket"
|
||||
class="mr-2"
|
||||
>
|
||||
<v-icon>mdi-refresh</v-icon>
|
||||
{{ tm('buttons.refresh') }}
|
||||
</v-btn>
|
||||
<v-switch v-model="showPluginFullName" :label="tm('market.showFullName')" hide-details density="compact"
|
||||
style="margin-left: 12px" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<v-col cols="12" md="12" style="padding: 0px;">
|
||||
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
|
||||
@@ -904,7 +922,8 @@ onMounted(async () => {
|
||||
</template>
|
||||
<template v-slot:item.tags="{ item }">
|
||||
<span v-if="item.tags.length === 0">-</span>
|
||||
<v-chip v-for="tag in item.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'" size="x-small" v-show="tag !== 'danger'">
|
||||
<v-chip v-for="tag in item.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'"
|
||||
size="x-small" v-show="tag !== 'danger'">
|
||||
{{ tag }}</v-chip>
|
||||
</template>
|
||||
<template v-slot:item.actions="{ item }">
|
||||
@@ -958,7 +977,8 @@ onMounted(async () => {
|
||||
<v-icon icon="mdi-alert" color="warning" size="64" class="mb-4"></v-icon>
|
||||
<div class="text-h5 mb-2">{{ tm('dialogs.platformConfig.noAdapters') }}</div>
|
||||
<div class="text-body-1 mb-4">{{ tm('dialogs.platformConfig.noAdaptersDesc') }}</div>
|
||||
<v-btn color="primary" to="/platforms" variant="elevated">{{ tm('dialogs.platformConfig.goPlatforms') }}</v-btn>
|
||||
<v-btn color="primary" to="/platforms" variant="elevated">{{ tm('dialogs.platformConfig.goPlatforms')
|
||||
}}</v-btn>
|
||||
</div>
|
||||
|
||||
<v-sheet v-else class="rounded-lg overflow-hidden">
|
||||
@@ -1002,7 +1022,8 @@ onMounted(async () => {
|
||||
<td>
|
||||
<div class="d-flex align-center">
|
||||
{{ plugin.name }}
|
||||
<v-chip v-if="plugin.reserved" color="primary" size="x-small" class="ml-2">{{ tm('status.system') }}</v-chip>
|
||||
<v-chip v-if="plugin.reserved" color="primary" size="x-small" class="ml-2">{{ tm('status.system')
|
||||
}}</v-chip>
|
||||
</div>
|
||||
<div class="text-caption text-grey">{{ plugin.desc }}</div>
|
||||
</td>
|
||||
@@ -1018,8 +1039,8 @@ onMounted(async () => {
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" text @click="platformEnableDialog = false">{{ tm('buttons.close') }}</v-btn>
|
||||
<v-btn v-if="platformEnableData.platforms.length > 0" color="primary"
|
||||
@click="savePlatformEnableConfig">{{ tm('buttons.save') }}</v-btn>
|
||||
<v-btn v-if="platformEnableData.platforms.length > 0" color="primary" @click="savePlatformEnableConfig">{{
|
||||
tm('buttons.save') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -1058,7 +1079,8 @@ onMounted(async () => {
|
||||
|
||||
<div style="margin-top: 32px;">
|
||||
<h3>{{ tm('dialogs.loading.logs') }}</h3>
|
||||
<ConsoleDisplayer historyNum="10" style="height: 200px; margin-top: 16px; margin-bottom: 24px;"></ConsoleDisplayer>
|
||||
<ConsoleDisplayer historyNum="10" style="height: 200px; margin-top: 16px; margin-bottom: 24px;">
|
||||
</ConsoleDisplayer>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
@@ -1099,7 +1121,8 @@ onMounted(async () => {
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="blue-darken-1" variant="text" @click="showPluginInfoDialog = false">{{ tm('buttons.close') }}</v-btn>
|
||||
<v-btn color="blue-darken-1" variant="text" @click="showPluginInfoDialog = false">{{ tm('buttons.close')
|
||||
}}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -1146,22 +1169,10 @@ onMounted(async () => {
|
||||
<v-window v-model="uploadTab" class="mt-4">
|
||||
<v-window-item value="file">
|
||||
<div class="d-flex flex-column align-center justify-center pa-4">
|
||||
<v-file-input
|
||||
ref="fileInput"
|
||||
v-model="upload_file"
|
||||
:label="tm('upload.selectFile')"
|
||||
accept=".zip"
|
||||
hide-details
|
||||
hide-input
|
||||
class="d-none"
|
||||
></v-file-input>
|
||||
<v-file-input ref="fileInput" v-model="upload_file" :label="tm('upload.selectFile')" accept=".zip"
|
||||
hide-details hide-input class="d-none"></v-file-input>
|
||||
|
||||
<v-btn
|
||||
color="primary"
|
||||
size="large"
|
||||
prepend-icon="mdi-upload"
|
||||
@click="$refs.fileInput.click()"
|
||||
>
|
||||
<v-btn color="primary" size="large" prepend-icon="mdi-upload" @click="$refs.fileInput.click()">
|
||||
{{ tm('buttons.selectFile') }}
|
||||
</v-btn>
|
||||
|
||||
@@ -1182,14 +1193,12 @@ onMounted(async () => {
|
||||
|
||||
<v-window-item value="url">
|
||||
<div class="pa-4">
|
||||
<v-text-field
|
||||
v-model="extension_url"
|
||||
:label="tm('upload.enterUrl')"
|
||||
variant="outlined"
|
||||
prepend-inner-icon="mdi-link"
|
||||
hide-details
|
||||
placeholder="https://github.com/username/repo"
|
||||
></v-text-field>
|
||||
<v-text-field v-model="extension_url" :label="tm('upload.enterUrl')" variant="outlined"
|
||||
prepend-inner-icon="mdi-link" hide-details
|
||||
placeholder="https://github.com/username/repo"></v-text-field>
|
||||
<div class="mt-4">
|
||||
<ProxySelector></ProxySelector>
|
||||
</div>
|
||||
</div>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
|
||||
@@ -265,7 +265,7 @@ export default {
|
||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||
} else if (name === 'wecom') {
|
||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||
} else if (name === 'gewechat' || name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||
} else if (name === 'lark') {
|
||||
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
|
||||
@@ -290,7 +290,6 @@ export default {
|
||||
"qq_official": "https://astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||
"wecom": "https://astrbot.app/deploy/platform/wecom.html",
|
||||
"gewechat": "https://astrbot.app/deploy/platform/wechat/gewechat.html",
|
||||
"lark": "https://astrbot.app/deploy/platform/lark.html",
|
||||
"telegram": "https://astrbot.app/deploy/platform/telegram.html",
|
||||
"dingtalk": "https://astrbot.app/deploy/platform/dingtalk.html",
|
||||
|
||||
@@ -60,6 +60,7 @@
|
||||
title-field="id"
|
||||
enabled-field="enable"
|
||||
@toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)"
|
||||
@delete="deleteProvider"
|
||||
@edit="configExistingProvider">
|
||||
<template v-slot:details="{ item }">
|
||||
@@ -199,7 +200,7 @@
|
||||
</v-card-text>
|
||||
</div>
|
||||
<div class="provider-card-logo">
|
||||
<img :src="getProviderIcon(name)" v-if="getProviderIcon(name)" class="provider-logo-img">
|
||||
<img :src="getProviderIcon(template.provider)" v-if="getProviderIcon(template.provider)" class="provider-logo-img">
|
||||
<div v-else class="provider-logo-fallback">
|
||||
{{ name[0].toUpperCase() }}
|
||||
</div>
|
||||
@@ -541,34 +542,27 @@ export default {
|
||||
// 获取提供商类型对应的图标
|
||||
getProviderIcon(type) {
|
||||
const icons = {
|
||||
'OpenAI': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||
'Azure OpenAI': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||
'Whisper': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||
'xAI': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
|
||||
'Anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
|
||||
'Ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
|
||||
'Gemini': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
||||
'Gemini(OpenAI兼容)': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
||||
'DeepSeek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
|
||||
'智谱 AI': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
|
||||
'硅基流动': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
|
||||
'Kimi': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
||||
'PPIO派欧云': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
||||
'Dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
||||
'阿里云百炼': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
||||
'FastGPT': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
||||
'LM Studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
||||
'FishAudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
|
||||
'Azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
|
||||
'MiniMax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
|
||||
'302.AI': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
|
||||
'openai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/openai.svg',
|
||||
'azure': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/azure.svg',
|
||||
'xai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/xai.svg',
|
||||
'anthropic': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/anthropic.svg',
|
||||
'ollama': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ollama.svg',
|
||||
'google': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/gemini-color.svg',
|
||||
'deepseek': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/deepseek.svg',
|
||||
'zhipu': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/zhipu.svg',
|
||||
'siliconflow': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/siliconcloud.svg',
|
||||
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
||||
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
||||
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
||||
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
||||
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
||||
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
||||
'fishaudio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fishaudio.svg',
|
||||
'minimax': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/minimax.svg',
|
||||
'302ai': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/1.53.0/files/icons/ai302-color.svg',
|
||||
'microsoft': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/microsoft.svg',
|
||||
};
|
||||
for (const key in icons) {
|
||||
if (type.startsWith(key)) {
|
||||
return icons[key];
|
||||
}
|
||||
}
|
||||
return ''
|
||||
return icons[type] || '';
|
||||
},
|
||||
|
||||
// 获取Tab类型的中文名称
|
||||
|
||||
1081
dashboard/src/views/SessionManagementPage.vue
Normal file
1081
dashboard/src/views/SessionManagementPage.vue
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,11 +5,8 @@
|
||||
<v-list lines="two">
|
||||
<v-list-subheader>{{ tm('network.title') }}</v-list-subheader>
|
||||
|
||||
<v-list-item :subtitle="tm('network.githubProxy.subtitle')" :title="tm('network.githubProxy.title')">
|
||||
|
||||
<v-combobox variant="outlined" style="width: 100%; margin-top: 16px;" v-model="selectedGitHubProxy" :items="githubProxies"
|
||||
:label="tm('network.githubProxy.label')">
|
||||
</v-combobox>
|
||||
<v-list-item>
|
||||
<ProxySelector></ProxySelector>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-subheader>{{ tm('system.title') }}</v-list-subheader>
|
||||
@@ -17,41 +14,29 @@
|
||||
<v-list-item :subtitle="tm('system.restart.subtitle')" :title="tm('system.restart.title')">
|
||||
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">{{ tm('system.restart.button') }}</v-btn>
|
||||
</v-list-item>
|
||||
|
||||
|
||||
|
||||
|
||||
</v-list>
|
||||
|
||||
</div>
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
|
||||
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
components: {
|
||||
WaitingForRestart,
|
||||
ProxySelector,
|
||||
},
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/settings');
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
githubProxies: [
|
||||
"https://gh.llkk.cc",
|
||||
"https://gitproxy.click",
|
||||
],
|
||||
selectedGitHubProxy: "",
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
restartAstrBot() {
|
||||
axios.post('/api/stat/restart-core').then(() => {
|
||||
@@ -59,16 +44,5 @@ export default {
|
||||
})
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
this.selectedGitHubProxy = localStorage.getItem('selectedGitHubProxy') || "";
|
||||
},
|
||||
watch: {
|
||||
selectedGitHubProxy: function (newVal, oldVal) {
|
||||
if (!newVal) {
|
||||
newVal = ""
|
||||
}
|
||||
localStorage.setItem('selectedGitHubProxy', newVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -20,9 +20,16 @@
|
||||
</v-tooltip>
|
||||
</p>
|
||||
</div>
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showMcpServerDialog = true" rounded="xl" size="x-large">
|
||||
<div>
|
||||
<v-btn color="primary" prepend-icon="mdi-tools" class="me-2" variant="tonal" @click="showToolsDialog = true"
|
||||
rounded="xl" size="x-large">
|
||||
{{ tm('functionTools.buttons.view') }}({{ tools.length }})
|
||||
</v-btn>
|
||||
<v-btn color="success" prepend-icon="mdi-plus" variant="tonal" @click="showMcpServerDialog = true"
|
||||
rounded="xl" size="x-large">
|
||||
{{ tm('mcpServers.buttons.add') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-row>
|
||||
|
||||
<!-- 标签页切换 -->
|
||||
@@ -44,23 +51,7 @@
|
||||
<!-- 本地服务器标签页内容 -->
|
||||
<v-window-item value="local">
|
||||
<!-- MCP 服务器部分 -->
|
||||
<v-card class="mb-6" elevation="2">
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<v-icon color="primary" class="me-2">mdi-server</v-icon>
|
||||
<span class="text-h6">{{ tm('mcpServers.title') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="getServers" :loading="loading">
|
||||
{{ tm('mcpServers.buttons.refresh') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" style="margin-left: 8px;" prepend-icon="mdi-plus" variant="tonal"
|
||||
@click="showMcpServerDialog = true">
|
||||
{{ tm('mcpServers.buttons.add') }}
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-text class="px-4 py-3">
|
||||
<div v-if="mcpServers.length === 0" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-server-off</v-icon>
|
||||
<p class="text-grey mt-4">{{ tm('mcpServers.empty') }}</p>
|
||||
@@ -68,14 +59,8 @@
|
||||
|
||||
<v-row v-else>
|
||||
<v-col v-for="(server, index) in mcpServers || []" :key="index" cols="12" md="6" lg="4" xl="3">
|
||||
<item-card
|
||||
style="background-color: #f7f2f9;"
|
||||
:item="server"
|
||||
title-field="name"
|
||||
enabled-field="active"
|
||||
@toggle-enabled="updateServerStatus"
|
||||
@delete="deleteServer"
|
||||
@edit="editServer">
|
||||
<item-card style="background-color: rgb(var(--v-theme-mcpCardBg));" :item="server" title-field="name" enabled-field="active"
|
||||
@toggle-enabled="updateServerStatus" @delete="deleteServer" @edit="editServer">
|
||||
<template v-slot:item-details="{ item }">
|
||||
<div class="d-flex align-center mb-2">
|
||||
<v-icon size="small" color="grey" class="me-2">mdi-file-code</v-icon>
|
||||
@@ -84,129 +69,61 @@
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="d-flex" style="gap: 8px;">
|
||||
<div>
|
||||
<div v-if="item.tools && item.tools.length > 0">
|
||||
<div class="d-flex align-center mb-1">
|
||||
<v-icon size="small" color="grey" class="me-2">mdi-tools</v-icon>
|
||||
<span class="text-caption text-medium-emphasis">{{ tm('mcpServers.status.availableTools') }} ({{ item.tools.length }})</span>
|
||||
<v-dialog max-width="600px">
|
||||
<template v-slot:activator="{ props: listToolsProps }">
|
||||
<span class="text-caption text-medium-emphasis cursor-pointer" v-bind="listToolsProps"
|
||||
style="text-decoration: underline;">
|
||||
{{ tm('mcpServers.status.availableTools', { count: item.tools.length }) }} ({{
|
||||
item.tools.length }})
|
||||
</span>
|
||||
</template>
|
||||
<template v-slot:default="{ isActive }">
|
||||
<v-card style="padding: 16px;">
|
||||
<v-card-title class="d-flex align-center">
|
||||
<span>{{ tm('mcpServers.status.availableTools') }}</span>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<ul>
|
||||
<li v-for="(tool, idx) in item.tools" :key="idx" style="margin: 8px 0px;">{{
|
||||
tool
|
||||
}}
|
||||
</li>
|
||||
</ul>
|
||||
</v-card-text>
|
||||
<v-card-actions class="d-flex justify-end">
|
||||
<v-btn variant="text" color="primary" @click="isActive.value = false">
|
||||
Close
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</template>
|
||||
|
||||
|
||||
</v-dialog>
|
||||
</div>
|
||||
<v-chip-group class="tool-chips">
|
||||
<v-chip v-for="(tool, idx) in item.tools" :key="idx" size="x-small" density="compact" color="info"
|
||||
class="text-caption">
|
||||
{{ tool }}
|
||||
</v-chip>
|
||||
</v-chip-group>
|
||||
</div>
|
||||
<div v-else class="text-caption text-medium-emphasis mt-2">
|
||||
<div v-else class="text-caption text-medium-emphasis">
|
||||
<v-icon size="small" color="warning" class="me-1">mdi-alert-circle</v-icon>
|
||||
{{ tm('mcpServers.status.noTools') }}
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="mcpServerUpdateLoaders[item.name]" class="text-caption text-medium-emphasis">
|
||||
<v-progress-circular indeterminate color="primary" size="16"></v-progress-circular>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</template>
|
||||
</item-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
<!-- 函数工具部分 -->
|
||||
<v-card elevation="0" class="mt-4">
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<v-icon color="primary" class="me-2">mdi-function</v-icon>
|
||||
<span class="text-h4">{{ tm('functionTools.title') }}</span>
|
||||
<v-chip color="info" size="small" class="ml-2">{{ tools.length }}</v-chip>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" color="primary" @click="showTools = !showTools">
|
||||
{{ showTools ? tm('functionTools.buttons.collapse') : tm('functionTools.buttons.expand') }}
|
||||
<v-icon>{{ showTools ? 'mdi-chevron-up' : 'mdi-chevron-down' }}</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-expand-transition>
|
||||
<v-card-text class="pa-0" v-if="showTools">
|
||||
<div class="pa-4">
|
||||
<div v-if="tools.length === 0" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
|
||||
<p class="text-grey mt-4">{{ tm('functionTools.empty') }}</p>
|
||||
</div>
|
||||
|
||||
<div v-else>
|
||||
<v-text-field v-model="toolSearch" prepend-inner-icon="mdi-magnify" :label="tm('functionTools.search')" variant="outlined"
|
||||
density="compact" class="mb-4" hide-details clearable></v-text-field>
|
||||
|
||||
<v-expansion-panels v-model="openedPanel" multiple>
|
||||
<v-expansion-panel v-for="(tool, index) in filteredTools" :key="index" :value="index"
|
||||
class="mb-2 tool-panel" rounded="lg">
|
||||
<v-expansion-panel-title>
|
||||
<v-row no-gutters align="center">
|
||||
<v-col cols="3">
|
||||
<div class="d-flex align-center">
|
||||
<v-icon color="primary" class="me-2" size="small">
|
||||
{{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
|
||||
</v-icon>
|
||||
<span class="text-body-1 text-high-emphasis font-weight-medium text-truncate"
|
||||
:title="tool.function.name">
|
||||
{{ formatToolName(tool.function.name) }}
|
||||
</span>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="9" class="text-grey">
|
||||
{{ tool.function.description }}
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<v-card flat>
|
||||
<v-card-text>
|
||||
<p class="text-body-1 font-weight-medium mb-3">
|
||||
<v-icon color="primary" size="small" class="me-1">mdi-information</v-icon>
|
||||
{{ tm('functionTools.description') }}
|
||||
</p>
|
||||
<p class="text-body-2 ml-6 mb-4">{{ tool.function.description }}</p>
|
||||
|
||||
<template v-if="tool.function.parameters && tool.function.parameters.properties">
|
||||
<p class="text-body-1 font-weight-medium mb-3">
|
||||
<v-icon color="primary" size="small" class="me-1">mdi-code-json</v-icon>
|
||||
{{ tm('functionTools.parameters') }}
|
||||
</p>
|
||||
|
||||
<v-table density="compact" class="params-table mt-1">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>{{ tm('functionTools.table.paramName') }}</th>
|
||||
<th>{{ tm('functionTools.table.type') }}</th>
|
||||
<th>{{ tm('functionTools.table.description') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="(param, paramName) in tool.function.parameters.properties"
|
||||
:key="paramName">
|
||||
<td class="font-weight-medium">{{ paramName }}</td>
|
||||
<td>
|
||||
<v-chip size="x-small" color="primary" text class="text-caption">
|
||||
{{ param.type }}
|
||||
</v-chip>
|
||||
</td>
|
||||
<td>{{ param.description }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</v-table>
|
||||
</template>
|
||||
<div v-else class="text-center pa-4 text-medium-emphasis">
|
||||
<v-icon size="large" color="grey-lighten-1">mdi-code-brackets</v-icon>
|
||||
<p>{{ tm('functionTools.noParameters') }}</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
</v-expansion-panels>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-expand-transition>
|
||||
</v-card>
|
||||
</v-window-item>
|
||||
|
||||
<!-- MCP市场标签页内容 -->
|
||||
@@ -216,9 +133,9 @@
|
||||
<v-icon color="primary" class="me-2">mdi-store</v-icon>
|
||||
<span class="text-h6">{{ tm('marketplace.title') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-text-field v-model="marketplaceSearch" prepend-inner-icon="mdi-magnify" :label="tm('marketplace.search')"
|
||||
variant="outlined" density="compact" hide-details class="mx-2" style="max-width: 300px" clearable
|
||||
@update:model-value="searchMarketplaceServers"></v-text-field>
|
||||
<v-text-field v-model="marketplaceSearch" prepend-inner-icon="mdi-magnify"
|
||||
:label="tm('marketplace.search')" variant="outlined" density="compact" hide-details class="mx-2"
|
||||
style="max-width: 300px" clearable @update:model-value="searchMarketplaceServers"></v-text-field>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="text" @click="fetchMarketplaceServers(1)"
|
||||
:loading="marketplaceLoading">
|
||||
{{ tm('marketplace.buttons.refresh') }}
|
||||
@@ -256,7 +173,8 @@
|
||||
<div class="d-flex align-center mb-2">
|
||||
<v-icon size="small" color="grey" class="me-2">mdi-tools</v-icon>
|
||||
<span class="text-caption text-medium-emphasis">
|
||||
{{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 }) }}
|
||||
{{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 })
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
@@ -310,31 +228,25 @@
|
||||
|
||||
<v-card-text class="py-4">
|
||||
<v-form @submit.prevent="saveServer" ref="form">
|
||||
<v-text-field v-model="currentServer.name" :label="tm('dialogs.addServer.fields.name')" variant="outlined" :rules="[v => !!v || tm('dialogs.addServer.fields.nameRequired')]"
|
||||
required class="mb-3"></v-text-field>
|
||||
|
||||
<v-switch v-model="currentServer.active" :label="tm('dialogs.addServer.fields.enable')" color="primary" hide-details class="mb-3"></v-switch>
|
||||
<v-text-field v-model="currentServer.name" :label="tm('dialogs.addServer.fields.name')" variant="outlined"
|
||||
:rules="[v => !!v || tm('dialogs.addServer.fields.nameRequired')]" required class="mb-3"></v-text-field>
|
||||
|
||||
<div class="mb-2 d-flex align-center">
|
||||
<span class="text-subtitle-1">{{ tm('dialogs.addServer.fields.config') }}</span>
|
||||
<v-tooltip location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-icon v-bind="props" class="ms-2" size="small" color="primary">mdi-information</v-icon>
|
||||
</template>
|
||||
<div style="white-space: pre-line;">
|
||||
{{ tm('tooltip.serverConfig') }}
|
||||
</div>
|
||||
</v-tooltip>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn size="small" color="info" variant="text" @click="setConfigTemplate" class="me-1">
|
||||
{{ tm('mcpServers.buttons.useTemplate') }}
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="setConfigTemplate('stdio')" class="me-1">
|
||||
{{ tm('mcpServers.buttons.useTemplateStdio') }}
|
||||
</v-btn>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="setConfigTemplate('streamable_http')"
|
||||
class="me-1">
|
||||
{{ tm('mcpServers.buttons.useTemplateStreamableHttp') }}
|
||||
</v-btn>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="setConfigTemplate('sse')" class="me-1">
|
||||
{{ tm('mcpServers.buttons.useTemplateSse') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
<small>{{ tm('dialogs.addServer.configNotes.note1') }}</small>
|
||||
<br>
|
||||
<small>{{ tm('dialogs.addServer.configNotes.note2') }}</small>
|
||||
|
||||
<div class="monaco-container">
|
||||
<div class="monaco-container" style="margin-top: 16px;">
|
||||
<VueMonacoEditor v-model:value="serverConfigJson" theme="vs-dark" language="json" :options="{
|
||||
minimap: {
|
||||
enabled: false
|
||||
@@ -353,15 +265,20 @@
|
||||
</div>
|
||||
|
||||
</v-form>
|
||||
</v-card-text>
|
||||
<div style="margin-top: 8px;">
|
||||
<small>{{ addServerDialogMessage }}</small>
|
||||
</div>
|
||||
|
||||
<v-divider></v-divider>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="closeServerDialog" :disabled="loading">
|
||||
{{ tm('dialogs.addServer.buttons.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn variant="text" @click="testServerConnection" :disabled="loading">
|
||||
{{ tm('dialogs.addServer.buttons.testConnection') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" @click="saveServer" :loading="loading" :disabled="!isServerFormValid">
|
||||
{{ tm('dialogs.addServer.buttons.save') }}
|
||||
</v-btn>
|
||||
@@ -469,6 +386,106 @@
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 函数工具对话框 -->
|
||||
<v-dialog v-model="showToolsDialog" max-width="800px">
|
||||
<v-card elevation="0" class="mt-4">
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
{{ tm('functionTools.title') }}
|
||||
<v-chip color="info" size="small" class="ml-2">{{ tools.length }}</v-chip>
|
||||
</v-card-title>
|
||||
<v-expand-transition>
|
||||
<v-card-text class="pa-0" v-if="showTools">
|
||||
<div class="pa-4">
|
||||
<div v-if="tools.length === 0" class="text-center pa-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
|
||||
<p class="text-grey mt-4">{{ tm('functionTools.empty') }}</p>
|
||||
</div>
|
||||
|
||||
<div v-else>
|
||||
<v-text-field v-model="toolSearch" prepend-inner-icon="mdi-magnify" :label="tm('functionTools.search')"
|
||||
variant="outlined" density="compact" class="mb-4" hide-details clearable></v-text-field>
|
||||
|
||||
<v-expansion-panels v-model="openedPanel" multiple style="max-height: 500px; overflow-y: auto;">
|
||||
<v-expansion-panel v-for="(tool, index) in filteredTools" :key="index" :value="index"
|
||||
class="mb-2 tool-panel" rounded="lg">
|
||||
<v-expansion-panel-title>
|
||||
<v-row no-gutters align="center">
|
||||
<v-col cols="3">
|
||||
<div class="d-flex align-center">
|
||||
<v-icon color="primary" class="me-2" size="small">
|
||||
{{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
|
||||
</v-icon>
|
||||
<span class="text-body-1 text-high-emphasis font-weight-medium text-truncate"
|
||||
:title="tool.function.name">
|
||||
{{ formatToolName(tool.function.name) }}
|
||||
</span>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="9" class="text-grey">
|
||||
{{ tool.function.description }}
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<v-card flat>
|
||||
<v-card-text>
|
||||
<p class="text-body-1 font-weight-medium mb-3">
|
||||
<v-icon color="primary" size="small" class="me-1">mdi-information</v-icon>
|
||||
{{ tm('functionTools.description') }}
|
||||
</p>
|
||||
<p class="text-body-2 ml-6 mb-4">{{ tool.function.description }}</p>
|
||||
|
||||
<template v-if="tool.function.parameters && tool.function.parameters.properties">
|
||||
<p class="text-body-1 font-weight-medium mb-3">
|
||||
<v-icon color="primary" size="small" class="me-1">mdi-code-json</v-icon>
|
||||
{{ tm('functionTools.parameters') }}
|
||||
</p>
|
||||
|
||||
<v-table density="compact" class="params-table mt-1">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>{{ tm('functionTools.table.paramName') }}</th>
|
||||
<th>{{ tm('functionTools.table.type') }}</th>
|
||||
<th>{{ tm('functionTools.table.description') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="(param, paramName) in tool.function.parameters.properties" :key="paramName">
|
||||
<td class="font-weight-medium">{{ paramName }}</td>
|
||||
<td>
|
||||
<v-chip size="x-small" color="primary" text class="text-caption">
|
||||
{{ param.type }}
|
||||
</v-chip>
|
||||
</td>
|
||||
<td>{{ param.description }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</v-table>
|
||||
</template>
|
||||
<div v-else class="text-center pa-4 text-medium-emphasis">
|
||||
<v-icon size="large" color="grey-lighten-1">mdi-code-brackets</v-icon>
|
||||
<p>{{ tm('functionTools.noParameters') }}</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
</v-expansion-panels>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-expand-transition>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="showToolsDialog = false">
|
||||
{{ tm('dialogs.serverDetail.buttons.close') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack"
|
||||
location="top">
|
||||
@@ -504,8 +521,12 @@ export default {
|
||||
tools: [],
|
||||
showMcpServerDialog: false,
|
||||
showServerDetailDialog: false,
|
||||
addServerDialogMessage: "",
|
||||
showToolsDialog: false,
|
||||
showTools: true,
|
||||
loading: false,
|
||||
loadingGettingServers: false,
|
||||
mcpServerUpdateLoaders: {}, // record loading state for each server update
|
||||
isEditMode: false,
|
||||
serverConfigJson: '',
|
||||
jsonError: null,
|
||||
@@ -618,17 +639,21 @@ export default {
|
||||
},
|
||||
|
||||
getServers() {
|
||||
this.loading = true
|
||||
this.loadingGettingServers = true;
|
||||
axios.get('/api/tools/mcp/servers')
|
||||
.then(response => {
|
||||
this.mcpServers = response.data.data || [];
|
||||
this.mcpServers.forEach(server => {
|
||||
// Ensure each server has a loader state
|
||||
if (!this.mcpServerUpdateLoaders[server.name]) {
|
||||
this.mcpServerUpdateLoaders[server.name] = false;
|
||||
}
|
||||
});
|
||||
})
|
||||
.catch(error => {
|
||||
this.showError(this.tm('messages.getServersError', { error: error.message }));
|
||||
}).finally(() => {
|
||||
setTimeout(() => {
|
||||
this.loading = false;
|
||||
}, 500);
|
||||
this.loadingGettingServers = false;
|
||||
});
|
||||
},
|
||||
|
||||
@@ -658,14 +683,28 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
setConfigTemplate() {
|
||||
// 设置一个基本的配置模板
|
||||
const template = {
|
||||
setConfigTemplate(type = 'stdio') {
|
||||
let template = {};
|
||||
if (type === 'streamable_http') {
|
||||
template = {
|
||||
transport: "streamable_http",
|
||||
url: "your mcp server url",
|
||||
headers: {},
|
||||
timeout: 30,
|
||||
};
|
||||
} else if (type === 'sse') {
|
||||
template = {
|
||||
transport: "sse",
|
||||
url: "your mcp server url",
|
||||
headers: {},
|
||||
timeout: 30,
|
||||
};
|
||||
} else {
|
||||
template = {
|
||||
command: "python",
|
||||
args: ["-m", "your_module"],
|
||||
// 可以添加其他 MCP 支持的配置项
|
||||
};
|
||||
|
||||
}
|
||||
this.serverConfigJson = JSON.stringify(template, null, 2);
|
||||
},
|
||||
|
||||
@@ -693,6 +732,7 @@ export default {
|
||||
.then(response => {
|
||||
this.loading = false;
|
||||
this.showMcpServerDialog = false;
|
||||
this.addServerDialogMessage = "";
|
||||
this.getServers();
|
||||
this.getTools();
|
||||
this.showSuccess(response.data.message || this.tm('messages.saveSuccess'));
|
||||
@@ -753,6 +793,7 @@ export default {
|
||||
|
||||
updateServerStatus(server) {
|
||||
// 切换服务器状态
|
||||
this.mcpServerUpdateLoaders[server.name] = true;
|
||||
server.active = !server.active;
|
||||
axios.post('/api/tools/mcp/update', server)
|
||||
.then(response => {
|
||||
@@ -761,16 +802,48 @@ export default {
|
||||
})
|
||||
.catch(error => {
|
||||
this.showError(this.tm('messages.updateError', { error: error.response?.data?.message || error.message }));
|
||||
// 回滚状态
|
||||
server.active = !server.active;
|
||||
})
|
||||
.finally(() => {
|
||||
this.mcpServerUpdateLoaders[server.name] = false;
|
||||
});
|
||||
},
|
||||
|
||||
closeServerDialog() {
|
||||
this.showMcpServerDialog = false;
|
||||
this.addServerDialogMessage = '';
|
||||
this.resetForm();
|
||||
},
|
||||
|
||||
testServerConnection() {
|
||||
if (!this.validateJson()) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.loading = true;
|
||||
|
||||
let configObj;
|
||||
try {
|
||||
configObj = JSON.parse(this.serverConfigJson);
|
||||
} catch (e) {
|
||||
this.loading = false;
|
||||
this.showError(this.tm('dialogs.addServer.errors.jsonParse', { error: e.message }));
|
||||
return;
|
||||
}
|
||||
|
||||
axios.post('/api/tools/mcp/test', {
|
||||
"mcp_server_config": configObj,
|
||||
})
|
||||
.then(response => {
|
||||
this.loading = false;
|
||||
this.addServerDialogMessage = `${response.data.message} (tools: ${response.data.data})`;
|
||||
})
|
||||
.catch(error => {
|
||||
this.loading = false;
|
||||
this.showError(this.tm('messages.testError', { error: error.response?.data?.message || error.message }));
|
||||
});
|
||||
},
|
||||
|
||||
resetForm() {
|
||||
this.currentServer = {
|
||||
name: '',
|
||||
@@ -939,7 +1012,7 @@ export default {
|
||||
|
||||
.monaco-container {
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: 4px;
|
||||
border-radius: 8px;
|
||||
height: 300px;
|
||||
margin-top: 4px;
|
||||
overflow: hidden;
|
||||
|
||||
@@ -135,7 +135,7 @@ class LongTermMemory:
|
||||
return
|
||||
|
||||
if event.get_result() and event.get_result().is_llm_result():
|
||||
final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import aiohttp
|
||||
import datetime
|
||||
import builtins
|
||||
@@ -16,7 +15,6 @@ from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
@@ -27,6 +25,7 @@ from astrbot.core.config.default import VERSION
|
||||
from .long_term_memory import LongTermMemory
|
||||
from astrbot.core import logger
|
||||
from astrbot.api.message_components import Plain, Image, Reply
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from typing import Union
|
||||
from enum import Enum
|
||||
|
||||
@@ -335,16 +334,18 @@ class Main(star.Star):
|
||||
|
||||
@filter.command("tts")
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音"""
|
||||
config = self.context.get_config()
|
||||
if config["provider_tts_settings"]["enable"]:
|
||||
config["provider_tts_settings"]["enable"] = False
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已关闭文本转语音。"))
|
||||
return
|
||||
config["provider_tts_settings"]["enable"] = True
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转语音。"))
|
||||
"""开关文本转语音(会话级别)"""
|
||||
session_id = event.unified_msg_origin
|
||||
current_status = SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
# 切换状态
|
||||
new_status = not current_status
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, new_status)
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{status_text}当前会话的文本转语音。")
|
||||
)
|
||||
|
||||
@filter.command("sid")
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
@@ -1150,24 +1151,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
sp.put("session_variables", session_vars)
|
||||
yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。")
|
||||
|
||||
@filter.command("gewe_logout")
|
||||
async def gewe_logout(self, event: AstrMessageEvent):
|
||||
platforms = self.context.platform_manager.platform_insts
|
||||
for platform in platforms:
|
||||
if platform.meta().name == "gewechat":
|
||||
yield event.plain_result("正在登出 gewechat")
|
||||
await platform.logout()
|
||||
yield event.plain_result("已登出 gewechat,请重启 AstrBot")
|
||||
return
|
||||
|
||||
@filter.command("gewe_code")
|
||||
async def gewe_code(self, event: AstrMessageEvent, code: str):
|
||||
"""保存 gewechat 验证码"""
|
||||
code_path = os.path.join(get_astrbot_data_path(), "temp", "gewe_code")
|
||||
with open(code_path, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
yield event.plain_result("验证码已保存。")
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
"""群聊记忆增强"""
|
||||
@@ -1239,6 +1222,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
@filter.on_decorating_result()
|
||||
async def decorate_result(self, event: AstrMessageEvent):
|
||||
logger.debug("Decorating result for event: %s", event)
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
|
||||
@@ -129,9 +129,9 @@ class Main(star.Star):
|
||||
logger.info(
|
||||
"Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。"
|
||||
)
|
||||
await self.context._star_manager.turn_off_plugin(
|
||||
"astrbot-python-interpreter"
|
||||
)
|
||||
# await self.context._star_manager.turn_off_plugin(
|
||||
# "astrbot-python-interpreter"
|
||||
# )
|
||||
|
||||
async def file_upload(self, file_path: str):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "3.5.21"
|
||||
version = "3.5.23"
|
||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -37,7 +37,6 @@ watchfiles
|
||||
websockets
|
||||
faiss-cpu
|
||||
aiosqlite
|
||||
nh3
|
||||
py-cord>=2.6.1
|
||||
slack-sdk
|
||||
pydub
|
||||
BIN
samples/stt_health_check.wav
Normal file
BIN
samples/stt_health_check.wav
Normal file
Binary file not shown.
Reference in New Issue
Block a user