Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
54340cca18 stage 2025-10-10 19:41:18 +08:00
49 changed files with 460 additions and 2431 deletions

View File

@@ -11,8 +11,6 @@ reviewers:
- Larch-C
- anka-afk
- advent259141
- Fridemn
- LIghtJUNction
# - zouyonghe
# A number of reviewers added to the pull request

View File

@@ -60,7 +60,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -88,6 +88,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

106
README.md
View File

@@ -1,3 +1,5 @@
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
@@ -11,17 +13,17 @@
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
## 主要功能
@@ -33,7 +35,7 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
## 部署方式
#### Docker 部署(推荐 🥳)
#### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
@@ -99,6 +101,7 @@ uv run main.py
- 5 群822130018
- 6 群753075035
- 开发者群975206796
- 开发者群备份295657329
### Telegram 群组
@@ -110,80 +113,48 @@ uv run main.py
## ⚡ 消息平台支持情况
**官方维护**
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方平台) | ✔ |
| QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企业微信 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| 企微智能机器人 | 将支持 |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
## ⚡ 提供商支持情况
**大模型服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | 文本生成 | |
| Google Gemini | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
## ❤️ 贡献
@@ -215,10 +186,19 @@ pre-commit install
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
另外,一些同类型其他的活跃开源 Bot 项目:
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
- [KroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
## ⭐ Star History
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">

View File

@@ -40,15 +40,8 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
timeout = cfg.get("timeout", 10)
try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http":
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
@@ -128,14 +121,7 @@ class MCPClient:
if not success:
raise Exception(error_msg)
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
if transport_type != "streamable_http":
if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
@@ -148,7 +134,7 @@ class MCPClient:
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
@@ -173,7 +159,7 @@ class MCPClient:
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,

View File

@@ -9,4 +9,3 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
tool_call_timeout: int = 60 # Default tool call timeout in seconds

View File

@@ -6,7 +6,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.3.5"
VERSION = "4.3.2"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -57,7 +57,6 @@ DEFAULT_CONFIG = {
"web_search": False,
"websearch_provider": "default",
"websearch_tavily_key": [],
"websearch_baidu_app_builder_key": "",
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
@@ -72,7 +71,6 @@ DEFAULT_CONFIG = {
"show_tool_use_status": False,
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
"provider_stt_settings": {
"enable": False,
@@ -209,18 +207,6 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6195,
},
"企业微信智能机器人": {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecomaibot_init_respond_text": "💭 思考中...",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"token": "",
"encoding_aes_key": "",
"callback_server_host": "0.0.0.0",
"port": 6198,
},
"飞书(Lark)": {
"id": "lark",
"type": "lark",
@@ -461,25 +447,10 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
},
"wecom_ai_bot_name": {
"description": "企业微信智能机器人的名字",
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
"hint": "请务必填,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
},
"discord_token": {
"description": "Discord Bot Token",
@@ -1085,7 +1056,6 @@ CONFIG_METADATA_2 = {
"timeout": "20",
},
"阿里云百炼 TTS(API)": {
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
"id": "dashscope_tts",
"provider": "dashscope",
"type": "dashscope_tts",
@@ -1465,7 +1435,11 @@ CONFIG_METADATA_2 = {
"description": "服务订阅密钥",
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
},
"dashscope_tts_voice": {"description": "音色", "type": "string"},
"dashscope_tts_voice": {
"description": "语音合成模型",
"type": "string",
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
},
"gm_resp_image_modal": {
"description": "启用图片模态",
"type": "bool",
@@ -1874,10 +1848,6 @@ CONFIG_METADATA_2 = {
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
},
},
"provider_stt_settings": {
@@ -2096,7 +2066,7 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": {
"description": "网页搜索提供商",
"type": "string",
"options": ["default", "tavily", "baidu_ai_search"],
"options": ["default", "tavily"],
},
"provider_settings.websearch_tavily_key": {
"description": "Tavily API Key",
@@ -2107,14 +2077,6 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": "tavily",
},
},
"provider_settings.websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"type": "string",
"hint": "参考https://console.bce.baidu.com/iam/#/iam/apikey/list",
"condition": {
"provider_settings.websearch_provider": "baidu_ai_search",
},
},
"provider_settings.web_search_link": {
"description": "显示来源引用",
"type": "bool",
@@ -2150,10 +2112,6 @@ CONFIG_METADATA_3 = {
"description": "工具调用轮数上限",
"type": "int",
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
"provider_settings.streaming_response": {
"description": "流式回复",
"type": "bool",

View File

@@ -6,7 +6,6 @@ import asyncio
import copy
import json
import traceback
from datetime import timedelta
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
@@ -186,33 +185,21 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
handler=awaitable,
**tool_args,
)
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
)
except StopAsyncIteration:
break
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
@classmethod
async def _execute_mcp(
@@ -230,9 +217,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout
),
)
if not res:
return
@@ -323,7 +307,6 @@ class LLMRequestSubStage(Stage):
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
@@ -490,7 +473,6 @@ class LLMRequestSubStage(Stage):
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
tool_call_timeout=self.tool_call_timeout,
)
await agent_runner.reset(
provider=provider,

View File

@@ -74,7 +74,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
if event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -82,10 +82,6 @@ class PlatformManager:
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter, # noqa: F401
)
case "wecom_ai_bot":
from .sources.wecom_ai_bot.wecomai_adapter import (
WecomAIBotAdapter, # noqa: F401
)
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa: F401

View File

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
from astrbot.api.message_components import Plain, Image, At, File, Record
if TYPE_CHECKING:
from .satori_adapter import SatoriPlatformAdapter
@@ -87,17 +87,6 @@ class SatoriPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
content_parts.append(f'<reply id="{component.id}"/>')
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
content_parts.append(f'<video src="{video_path_url}"/>')
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
content = "".join(content_parts)
channel_id = session_id
data = {"channel_id": channel_id, "content": content}
@@ -177,17 +166,6 @@ class SatoriPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
content_parts.append(f'<reply id="{component.id}"/>')
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
content_parts.append(f'<video src="{video_path_url}"/>')
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
content = "".join(content_parts)
channel_id = self.session_id
data = {"channel_id": channel_id, "content": content}

View File

@@ -91,6 +91,7 @@ class WebChatAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE

View File

@@ -1,289 +0,0 @@
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
"""对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2020 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
import random
import hashlib
import time
import struct
from Crypto.Cipher import AES
import socket
import json
from . import ierror
"""
关于Crypto.Cipher模块ImportError: No module named 'Crypto'解决方案
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
下载后按照README中的“Installation”小节的提示进行pycrypto安装。
"""
class FormatException(Exception):
pass
def throw_exception(message, exception_class=FormatException):
"""my define raise exception function"""
raise exception_class(message)
class SHA1:
"""计算企业微信的消息签名接口"""
def getSHA1(self, token, timestamp, nonce, encrypt):
"""用SHA1算法生成安全签名
@param token: 票据
@param timestamp: 时间戳
@param encrypt: 密文
@param nonce: 随机字符串
@return: 安全签名
"""
try:
# 确保所有输入都是字符串类型
if isinstance(encrypt, bytes):
encrypt = encrypt.decode("utf-8")
sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode("utf-8"))
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
class JsonParse:
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
# json消息模板
AES_TEXT_RESPONSE_TEMPLATE = """{
"encrypt": "%(msg_encrypt)s",
"msgsignature": "%(msg_signaturet)s",
"timestamp": "%(timestamp)s",
"nonce": "%(nonce)s"
}"""
def extract(self, jsontext):
"""提取出json数据包中的加密消息
@param jsontext: 待提取的json字符串
@return: 提取出的加密消息字符串
"""
try:
json_dict = json.loads(jsontext)
return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"]
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_ParseJson_Error, None
def generate(self, encrypt, signature, timestamp, nonce):
"""生成json消息
@param encrypt: 加密后的消息密文
@param signature: 安全签名
@param timestamp: 时间戳
@param nonce: 随机字符串
@return: 生成的json字符串
"""
resp_dict = {
"msg_encrypt": encrypt,
"msg_signaturet": signature,
"timestamp": timestamp,
"nonce": nonce,
}
resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
return resp_json
class PKCS7Encoder:
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
"""对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文(bytes类型)
@return: 补齐明文字符串(bytes类型)
"""
text_length = len(text)
# 计算需要填充的位数
amount_to_pad = self.block_size - (text_length % self.block_size)
if amount_to_pad == 0:
amount_to_pad = self.block_size
# 获得补位所用的字符
pad = bytes([amount_to_pad])
# 确保text是bytes类型
if isinstance(text, str):
text = text.encode("utf-8")
return text + pad * amount_to_pad
def decode(self, decrypted):
"""删除解密后明文的补位字符
@param decrypted: 解密后的明文
@return: 删除补位字符后的明文
"""
pad = ord(decrypted[-1])
if pad < 1 or pad > 32:
pad = 0
return decrypted[:-pad]
class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
self.mode = AES.MODE_CBC
def encrypt(self, text, receiveid):
"""对明文进行加密
@param text: 需要加密的明文
@return: 加密得到的字符串
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = (
self.get_random_str()
+ struct.pack("I", socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
text = pkcs7.encode(text)
# 加密
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
try:
ciphertext = cryptor.encrypt(text)
# 使用BASE64对加密后的字符串进行编码
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
except Exception as e:
logger = logging.getLogger("astrbot")
logger.error(e)
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
def decrypt(self, text, receiveid):
"""对解密后的明文进行补位删除
@param text: 密文
@return: 删除填充补位后的明文
"""
try:
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
# 使用BASE64对密文进行解码然后AES-CBC解密
plain_text = cryptor.decrypt(base64.b64decode(text))
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
try:
pad = plain_text[-1]
# 去掉补位字符串
# pkcs7 = PKCS7Encoder()
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
json_len = socket.ntohl(struct.unpack("I", content[:4])[0])
json_content = content[4 : json_len + 4].decode("utf-8")
from_receiveid = content[json_len + 4 :].decode("utf-8")
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_IllegalBuffer, None
if from_receiveid != receiveid:
print("receiveid not match", receiveid, from_receiveid)
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
return 0, json_content
def get_random_str(self):
"""随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizJsonMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
assert len(self.key) == 32
except Exception as e:
throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId
# 验证URL
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sEchoStr: 随机串对应URL参数的echostr
# @param sReplyEchoStr: 解密之后的echostr当return返回0时有效
# @return成功0失败返回对应的错误码
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
return ret, sReplyEchoStr
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
# 将企业回复用户的消息加密打包
# @param sReplyMsg: 企业号待回复用户的消息json格式的字符串
# @param sTimeStamp: 时间戳可以自己生成也可以用URL参数的timestamp,如为None则自动用当前时间
# @param sNonce: 随机串可以自己生成也可以用URL参数的nonce
# sEncryptMsg: 加密后的可以直接回复用户的密文包括msg_signature, timestamp, nonce, encrypt的json格式的字符串,
# return成功0sEncryptMsg,失败返回对应的错误码None
pc = Prpcrypt(self.key)
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
encrypt = encrypt.decode("utf-8") # type: ignore
if ret != 0:
return ret, None
if timestamp is None:
timestamp = str(int(time.time()))
# 生成安全签名
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
if ret != 0:
return ret, None
jsonParse = JsonParse()
return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce)
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
# 检验消息的真实性,并且获取解密后的明文
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sPostData: 密文对应POST请求的数据
# json_content: 解密后的原文当return返回0时有效
# @return: 成功0失败返回对应的错误码
# 验证安全签名
jsonParse = JsonParse()
ret, encrypt = jsonParse.extract(sPostData)
if ret != 0:
return ret, None
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
print("signature not match")
print(signature)
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId)
return ret, json_content

View File

@@ -1,17 +0,0 @@
"""
企业微信智能机器人平台适配器包
"""
from .wecomai_adapter import WecomAIBotAdapter
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import WecomAIBotConstants
__all__ = [
"WecomAIBotAdapter",
"WecomAIBotAPIClient",
"WecomAIBotMessageEvent",
"WecomAIBotServer",
"WecomAIBotConstants",
]

View File

@@ -1,20 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseJson_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnJson_Error = -40011

View File

@@ -1,445 +0,0 @@
"""
企业微信智能机器人平台适配器
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
参考webchat_adapter.py的队列机制实现异步消息处理和流式响应
"""
import time
import asyncio
import uuid
import hashlib
import base64
from typing import Awaitable, Any, Dict, Optional, Callable
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, At, Image
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .wecomai_api import (
WecomAIBotAPIClient,
WecomAIBotMessageParser,
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_server import WecomAIBotServer
from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr
from .wecomai_utils import (
WecomAIBotConstants,
format_session_id,
generate_random_string,
process_encrypted_image,
)
class WecomAIQueueListener:
"""企业微信智能机器人队列监听器参考webchat的QueueListener设计"""
def __init__(
self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]]
) -> None:
self.queue_mgr = queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, session_id: str):
"""监听特定会话的队列"""
queue = self.queue_mgr.get_or_create_queue(session_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
break
async def run(self):
"""监控新会话队列并启动监听器"""
monitored_sessions = set()
while True:
# 检查新会话
current_sessions = set(self.queue_mgr.queues.keys())
new_sessions = current_sessions - monitored_sessions
# 为新会话启动监听器
for session_id in new_sessions:
task = asyncio.create_task(self.listen_to_queue(session_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_sessions.add(session_id)
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
# 清理已不存在的会话
removed_sessions = monitored_sessions - current_sessions
monitored_sessions -= removed_sessions
# 清理过期的待处理响应
self.queue_mgr.cleanup_expired_responses()
await asyncio.sleep(1) # 每秒检查一次新会话
@register_platform_adapter(
"wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息"
)
class WecomAIBotAdapter(Platform):
"""企业微信智能机器人适配器"""
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
# 初始化配置参数
self.token = self.config["token"]
self.encoding_aes_key = self.config["encoding_aes_key"]
self.port = int(self.config["port"])
self.host = self.config.get("callback_server_host", "0.0.0.0")
self.bot_name = self.config.get("wecom_ai_bot_name", "")
self.initial_respond_text = self.config.get(
"wecomaibot_init_respond_text", "💭 思考中..."
)
self.friend_message_welcome_text = self.config.get(
"wecomaibot_friend_message_welcome_text", ""
)
# 平台元数据
self.metadata = PlatformMetadata(
name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
id=self.config.get("id", "wecom_ai_bot"),
)
# 初始化 API 客户端
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
# 初始化 HTTP 服务器
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
# 队列监听器
self.queue_listener = WecomAIQueueListener(
wecomai_queue_mgr, self._handle_queued_message
)
async def _handle_queued_message(self, data: dict):
"""处理队列中的消息类似webchat的callback"""
try:
abm = await self.convert_message(data)
await self.handle_msg(abm)
except Exception as e:
logger.error(f"处理队列消息时发生异常: {e}")
async def _process_message(
self, message_data: Dict[str, Any], callback_params: Dict[str, str]
) -> Optional[str]:
"""处理接收到的消息
Args:
message_data: 解密后的消息数据
callback_params: 回调参数 (nonce, timestamp)
Returns:
加密后的响应消息,无需响应时返回 None
"""
msgtype = message_data.get("msgtype")
if not msgtype:
logger.warning(f"消息类型未知,忽略: {message_data}")
return None
session_id = self._extract_session_id(message_data)
if msgtype in ("text", "image", "mixed"):
# user sent a text / image / mixed message
try:
# create a brand-new unique stream_id for this message session
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._enqueue_message(
message_data, callback_params, stream_id, session_id
)
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id, self.initial_respond_text, False
)
return await self.api_client.encrypt_message(
resp, callback_params["nonce"], callback_params["timestamp"]
)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return None
elif msgtype == "stream":
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not wecomai_queue_mgr.has_back_queue(stream_id):
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
# 返回结束标志,告诉微信服务器流已结束
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id, "", True
)
resp = await self.api_client.encrypt_message(
end_message,
callback_params["nonce"],
callback_params["timestamp"],
)
return resp
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if queue.empty():
logger.debug(
f"No new messages in back queue for stream_id: {stream_id}"
)
return None
# aggregate all delta chains in the back queue
latest_plain_content = ""
image_base64 = []
finish = False
while not queue.empty():
msg = await queue.get()
if msg["type"] == "plain":
latest_plain_content = msg["data"] or ""
elif msg["type"] == "image":
image_base64.append(msg["image_data"])
elif msg["type"] == "end":
# stream end
finish = True
wecomai_queue_mgr.remove_queues(stream_id)
break
else:
pass
logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}"
)
if latest_plain_content or image_base64:
msg_items = []
if finish and image_base64:
for img_b64 in image_base64:
# get md5 of image
img_data = base64.b64decode(img_b64)
img_md5 = hashlib.md5(img_data).hexdigest()
msg_items.append(
{
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
"image": {"base64": img_b64, "md5": img_md5},
}
)
image_base64 = []
plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream(
stream_id, latest_plain_content, msg_items, finish
)
encrypted_message = await self.api_client.encrypt_message(
plain_message,
callback_params["nonce"],
callback_params["timestamp"],
)
if encrypted_message:
logger.debug(
f"Stream message sent successfully, stream_id: {stream_id}"
)
else:
logger.error("消息加密失败")
return encrypted_message
return None
elif msgtype == "event":
event = message_data.get("event")
if event == "enter_chat" and self.friend_message_welcome_text:
# 用户进入会话,发送欢迎消息
try:
resp = WecomAIBotStreamMessageBuilder.make_text(
self.friend_message_welcome_text
)
return await self.api_client.encrypt_message(
resp,
callback_params["nonce"],
callback_params["timestamp"],
)
except Exception as e:
logger.error("处理欢迎消息时发生异常: %s", e)
return None
pass
def _extract_session_id(self, message_data: Dict[str, Any]) -> str:
"""从消息数据中提取会话ID"""
user_id = message_data.get("from", {}).get("userid", "default_user")
return format_session_id("wecomai", user_id)
async def _enqueue_message(
self,
message_data: Dict[str, Any],
callback_params: Dict[str, str],
stream_id: str,
session_id: str,
):
"""将消息放入队列进行异步处理"""
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
message_payload = {
"message_data": message_data,
"callback_params": callback_params,
"session_id": session_id,
"stream_id": stream_id,
}
await input_queue.put(message_payload)
logger.debug(f"[WecomAI] 消息已入队: {stream_id}")
async def convert_message(self, payload: dict) -> AstrBotMessage:
"""转换队列中的消息数据为AstrBotMessage类似webchat的convert_message"""
message_data = payload["message_data"]
session_id = payload["session_id"]
# callback_params = payload["callback_params"] # 保留但暂时不使用
# 解析消息内容
msgtype = message_data.get("msgtype")
content = ""
image_base64 = []
_img_url_to_process = []
msg_items = []
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
content = WecomAIBotMessageParser.parse_text_message(message_data)
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
_img_url_to_process.append(
WecomAIBotMessageParser.parse_image_message(message_data)
)
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
# 提取混合消息中的文本内容
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
text_parts = []
for item in msg_items or []:
if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT:
text_content = item.get("text", {}).get("content", "")
if text_content:
text_parts.append(text_content)
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_url = item.get("image", {}).get("url", "")
if image_url:
_img_url_to_process.append(image_url)
content = " ".join(text_parts) if text_parts else ""
else:
content = f"[{msgtype}消息]"
# 并行处理图片下载和解密
if _img_url_to_process:
tasks = [
process_encrypted_image(url, self.encoding_aes_key)
for url in _img_url_to_process
]
results = await asyncio.gather(*tasks)
for success, result in results:
if success:
image_base64.append(result)
else:
logger.error(f"处理加密图片失败: {result}")
# 构建 AstrBotMessage
abm = AstrBotMessage()
abm.self_id = self.bot_name
abm.message_str = content or "[未知消息]"
abm.message_id = str(uuid.uuid4())
abm.timestamp = int(time.time())
abm.raw_message = payload
# 发送者信息
abm.sender = MessageMember(
user_id=message_data.get("from", {}).get("userid", "unknown"),
nickname=message_data.get("from", {}).get("userid", "unknown"),
)
# 消息类型
abm.type = (
MessageType.GROUP_MESSAGE
if message_data.get("chattype") == "group"
else MessageType.FRIEND_MESSAGE
)
abm.session_id = session_id
# 消息内容
abm.message = []
# 处理 At
if self.bot_name and f"@{self.bot_name}" in abm.message_str:
abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip()
abm.message.append(At(qq=self.bot_name, name=self.bot_name))
abm.message.append(Plain(abm.message_str))
if image_base64:
for img_b64 in image_base64:
abm.message.append(Image.fromBase64(img_b64))
logger.debug(f"WecomAIAdapter: {abm.message}")
return abm
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
"""通过会话发送消息"""
# 企业微信智能机器人主要通过回调响应,这里记录日志
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
await super().send_by_session(session, message_chain)
def run(self) -> Awaitable[Any]:
"""运行适配器同时启动HTTP服务器和队列监听器"""
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
async def run_both():
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
return run_both()
async def terminate(self):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.server.shutdown()
def meta(self) -> PlatformMetadata:
"""获取平台元数据"""
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
"""处理消息,创建消息事件并提交到事件队列"""
try:
message_event = WecomAIBotMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
api_client=self.api_client,
)
self.commit_event(message_event)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
def get_client(self) -> WecomAIBotAPIClient:
"""获取 API 客户端"""
return self.api_client
def get_server(self) -> WecomAIBotServer:
"""获取 HTTP 服务器实例"""
return self.server

View File

@@ -1,378 +0,0 @@
"""
企业微信智能机器人 API 客户端
处理消息加密解密、API 调用等
"""
import json
import base64
import hashlib
from typing import Dict, Any, Optional, Tuple, Union
from Crypto.Cipher import AES
import aiohttp
from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt
from .wecomai_utils import WecomAIBotConstants
from astrbot import logger
class WecomAIBotAPIClient:
"""企业微信智能机器人 API 客户端"""
def __init__(self, token: str, encoding_aes_key: str):
"""初始化 API 客户端
Args:
token: 企业微信机器人 Token
encoding_aes_key: 消息加密密钥
"""
self.token = token
self.encoding_aes_key = encoding_aes_key
self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串
async def decrypt_message(
self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str
) -> Tuple[int, Optional[Dict[str, Any]]]:
"""解密企业微信消息
Args:
encrypted_data: 加密的消息数据
msg_signature: 消息签名
timestamp: 时间戳
nonce: 随机数
Returns:
(错误码, 解密后的消息数据字典)
"""
try:
ret, decrypted_msg = self.wxcpt.DecryptMsg(
encrypted_data, msg_signature, timestamp, nonce
)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"消息解密失败,错误码: {ret}")
return ret, None
# 解析 JSON
if decrypted_msg:
try:
message_data = json.loads(decrypted_msg)
logger.debug(f"解密成功,消息内容: {message_data}")
return WecomAIBotConstants.SUCCESS, message_data
except json.JSONDecodeError as e:
logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}")
return WecomAIBotConstants.PARSE_XML_ERROR, None
else:
logger.error("解密消息为空")
return WecomAIBotConstants.DECRYPT_ERROR, None
except Exception as e:
logger.error(f"解密过程发生异常: {e}")
return WecomAIBotConstants.DECRYPT_ERROR, None
async def encrypt_message(
self, plain_message: str, nonce: str, timestamp: str
) -> Optional[str]:
"""加密消息
Args:
plain_message: 明文消息
nonce: 随机数
timestamp: 时间戳
Returns:
加密后的消息,失败时返回 None
"""
try:
ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"消息加密失败,错误码: {ret}")
return None
logger.debug("消息加密成功")
return encrypted_msg
except Exception as e:
logger.error(f"加密过程发生异常: {e}")
return None
def verify_url(
self, msg_signature: str, timestamp: str, nonce: str, echostr: str
) -> str:
"""验证回调 URL
Args:
msg_signature: 消息签名
timestamp: 时间戳
nonce: 随机数
echostr: 验证字符串
Returns:
验证结果字符串
"""
try:
ret, echo_result = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"URL 验证失败,错误码: {ret}")
return "verify fail"
logger.info("URL 验证成功")
return echo_result if echo_result else "verify fail"
except Exception as e:
logger.error(f"URL 验证发生异常: {e}")
return "verify fail"
async def process_encrypted_image(
self, image_url: str, aes_key_base64: Optional[str] = None
) -> Tuple[bool, Union[bytes, str]]:
"""下载并解密加密图片
Args:
image_url: 加密图片的 URL
aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥
Returns:
(是否成功, 图片数据或错误信息)
"""
try:
# 下载图片
logger.info(f"开始下载加密图片: {image_url}")
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=15) as response:
if response.status != 200:
error_msg = f"图片下载失败,状态码: {response.status}"
logger.error(error_msg)
return False, error_msg
encrypted_data = await response.read()
logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节")
# 准备解密密钥
if aes_key_base64 is None:
aes_key_base64 = self.encoding_aes_key
if not aes_key_base64:
raise ValueError("AES 密钥不能为空")
# Base64 解码密钥
aes_key = base64.b64decode(
aes_key_base64 + "=" * (-len(aes_key_base64) % 4)
)
if len(aes_key) != 32:
raise ValueError("无效的 AES 密钥长度: 应为 32 字节")
iv = aes_key[:16] # 初始向量为密钥前 16 字节
# 解密图片数据
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = cipher.decrypt(encrypted_data)
# 去除 PKCS#7 填充
pad_len = decrypted_data[-1]
if pad_len > 32: # AES-256 块大小为 32 字节
raise ValueError("无效的填充长度 (大于32字节)")
decrypted_data = decrypted_data[:-pad_len]
logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节")
return True, decrypted_data
except aiohttp.ClientError as e:
error_msg = f"图片下载失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
except ValueError as e:
error_msg = f"参数错误: {str(e)}"
logger.error(error_msg)
return False, error_msg
except Exception as e:
error_msg = f"图片处理异常: {str(e)}"
logger.error(error_msg)
return False, error_msg
class WecomAIBotStreamMessageBuilder:
"""企业微信智能机器人流消息构建器"""
@staticmethod
def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str:
"""构建文本流消息
Args:
stream_id: 流 ID
content: 文本内容
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {"id": stream_id, "finish": finish, "content": content},
}
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_image_stream(
stream_id: str, image_data: bytes, finish: bool = False
) -> str:
"""构建图片流消息
Args:
stream_id: 流 ID
image_data: 图片二进制数据
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
image_md5 = hashlib.md5(image_data).hexdigest()
image_base64 = base64.b64encode(image_data).decode("utf-8")
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {
"id": stream_id,
"finish": finish,
"msg_item": [
{
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
"image": {"base64": image_base64, "md5": image_md5},
}
],
},
}
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_mixed_stream(
stream_id: str, content: str, msg_items: list, finish: bool = False
) -> str:
"""构建混合类型流消息
Args:
stream_id: 流 ID
content: 文本内容
msg_items: 消息项列表
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {"id": stream_id, "finish": finish, "msg_item": msg_items},
}
if content:
plain["stream"]["content"] = content
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_text(content: str) -> str:
"""构建文本消息
Args:
content: 文本内容
Returns:
JSON 格式的文本消息字符串
"""
plain = {"msgtype": "text", "text": {"content": content}}
return json.dumps(plain, ensure_ascii=False)
class WecomAIBotMessageParser:
"""企业微信智能机器人消息解析器"""
@staticmethod
def parse_text_message(data: Dict[str, Any]) -> Optional[str]:
"""解析文本消息
Args:
data: 消息数据
Returns:
文本内容,解析失败返回 None
"""
try:
return data.get("text", {}).get("content")
except (KeyError, TypeError):
logger.warning("文本消息解析失败")
return None
@staticmethod
def parse_image_message(data: Dict[str, Any]) -> Optional[str]:
"""解析图片消息
Args:
data: 消息数据
Returns:
图片 URL解析失败返回 None
"""
try:
return data.get("image", {}).get("url")
except (KeyError, TypeError):
logger.warning("图片消息解析失败")
return None
@staticmethod
def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""解析流消息
Args:
data: 消息数据
Returns:
流消息数据,解析失败返回 None
"""
try:
stream_data = data.get("stream", {})
return {
"id": stream_data.get("id"),
"finish": stream_data.get("finish"),
"content": stream_data.get("content"),
"msg_item": stream_data.get("msg_item", []),
}
except (KeyError, TypeError):
logger.warning("流消息解析失败")
return None
@staticmethod
def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]:
"""解析混合消息
Args:
data: 消息数据
Returns:
消息项列表,解析失败返回 None
"""
try:
return data.get("mixed", {}).get("msg_item", [])
except (KeyError, TypeError):
logger.warning("混合消息解析失败")
return None
@staticmethod
def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""解析事件消息
Args:
data: 消息数据
Returns:
事件数据,解析失败返回 None
"""
try:
return data.get("event", {})
except (KeyError, TypeError):
logger.warning("事件消息解析失败")
return None

View File

@@ -1,149 +0,0 @@
"""
企业微信智能机器人事件处理模块,处理消息事件的发送和接收
"""
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Plain,
)
from astrbot.api import logger
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_queue_mgr import wecomai_queue_mgr
class WecomAIBotMessageEvent(AstrMessageEvent):
"""企业微信智能机器人消息事件"""
def __init__(
self,
message_str: str,
message_obj,
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
):
"""初始化消息事件
Args:
message_str: 消息字符串
message_obj: 消息对象
platform_meta: 平台元数据
session_id: 会话 ID
api_client: API 客户端
"""
super().__init__(message_str, message_obj, platform_meta, session_id)
self.api_client = api_client
@staticmethod
async def _send(
message_chain: MessageChain,
stream_id: str,
streaming: bool = False,
):
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if not message_chain:
await back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
}
)
return ""
data = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
data = comp.text
await back_queue.put(
{
"type": "plain",
"data": data,
"streaming": streaming,
"session_id": stream_id,
}
)
elif isinstance(comp, Image):
# 处理图片消息
try:
image_base64 = await comp.convert_to_base64()
if image_base64:
await back_queue.put(
{
"type": "image",
"image_data": image_base64,
"streaming": streaming,
"session_id": stream_id,
}
)
else:
logger.warning("图片数据为空,跳过")
except Exception as e:
logger.error("处理图片消息失败: %s", e)
else:
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
return data
async def send(self, message: MessageChain):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id)
await super().send(message)
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息参考webchat的send_streaming设计"""
final_data = ""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
increment_plain = ""
async for chain in generator:
# 累积增量内容,并改写 Plain 段
chain.squash_plain()
for comp in chain.chain:
if isinstance(comp, Plain):
comp.text = increment_plain + comp.text
increment_plain = comp.text
break
if chain.type == "break" and final_data:
# 分割符
await back_queue.put(
{
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"session_id": self.session_id,
}
)
final_data = ""
continue
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
streaming=True,
)
await back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"session_id": self.session_id,
}
)
await super().send_streaming(generator, use_fallback)

View File

@@ -1,148 +0,0 @@
"""
企业微信智能机器人队列管理器
参考 webchat_queue_mgr.py为企业微信智能机器人实现队列机制
支持异步消息处理和流式响应
"""
import asyncio
from typing import Dict, Any, Optional
from astrbot.api import logger
class WecomAIQueueMgr:
"""企业微信智能机器人队列管理器"""
def __init__(self) -> None:
self.queues: Dict[str, asyncio.Queue] = {}
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
self.back_queues: Dict[str, asyncio.Queue] = {}
"""StreamID 到输出队列的映射 - 用于发送机器人响应"""
self.pending_responses: Dict[str, Dict[str, Any]] = {}
"""待处理的响应缓存,用于流式响应"""
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输入队列
Args:
session_id: 会话ID
Returns:
输入队列实例
"""
if session_id not in self.queues:
self.queues[session_id] = asyncio.Queue()
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
return self.queues[session_id]
def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输出队列
Args:
session_id: 会话ID
Returns:
输出队列实例
"""
if session_id not in self.back_queues:
self.back_queues[session_id] = asyncio.Queue()
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
return self.back_queues[session_id]
def remove_queues(self, session_id: str):
"""移除指定会话的所有队列
Args:
session_id: 会话ID
"""
if session_id in self.queues:
del self.queues[session_id]
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
if session_id in self.back_queues:
del self.back_queues[session_id]
logger.debug(f"[WecomAI] 移除输出队列: {session_id}")
if session_id in self.pending_responses:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
def has_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的队列
Args:
session_id: 会话ID
Returns:
是否存在队列
"""
return session_id in self.queues
def has_back_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的输出队列
Args:
session_id: 会话ID
Returns:
是否存在输出队列
"""
return session_id in self.back_queues
def set_pending_response(self, session_id: str, callback_params: Dict[str, str]):
"""设置待处理的响应参数
Args:
session_id: 会话ID
callback_params: 回调参数nonce, timestamp等
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": asyncio.get_event_loop().time(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]:
"""获取待处理的响应参数
Args:
session_id: 会话ID
Returns:
响应参数如果不存在则返回None
"""
return self.pending_responses.get(session_id)
def cleanup_expired_responses(self, max_age_seconds: int = 300):
"""清理过期的待处理响应
Args:
max_age_seconds: 最大存活时间(秒)
"""
current_time = asyncio.get_event_loop().time()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
if current_time - response_data["timestamp"] > max_age_seconds:
expired_sessions.append(session_id)
for session_id in expired_sessions:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
def get_stats(self) -> Dict[str, int]:
"""获取队列统计信息
Returns:
统计信息字典
"""
return {
"input_queues": len(self.queues),
"output_queues": len(self.back_queues),
"pending_responses": len(self.pending_responses),
}
# 全局队列管理器实例
wecomai_queue_mgr = WecomAIQueueMgr()

View File

@@ -1,166 +0,0 @@
"""
企业微信智能机器人 HTTP 服务器
处理企业微信智能机器人的 HTTP 回调请求
"""
import asyncio
from typing import Dict, Any, Optional, Callable
import quart
from astrbot.api import logger
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_utils import WecomAIBotConstants
class WecomAIBotServer:
"""企业微信智能机器人 HTTP 服务器"""
def __init__(
self,
host: str,
port: int,
api_client: WecomAIBotAPIClient,
message_handler: Optional[
Callable[[Dict[str, Any], Dict[str, str]], Any]
] = None,
):
"""初始化服务器
Args:
host: 监听地址
port: 监听端口
api_client: API客户端实例
message_handler: 消息处理回调函数
"""
self.host = host
self.port = port
self.api_client = api_client
self.message_handler = message_handler
self.app = quart.Quart(__name__)
self._setup_routes()
self.shutdown_event = asyncio.Event()
def _setup_routes(self):
"""设置 Quart 路由"""
# 使用 Quart 的 add_url_rule 方法添加路由
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.verify_url,
methods=["GET"],
)
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.handle_message,
methods=["POST"],
)
async def verify_url(self):
"""验证回调 URL"""
args = quart.request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
echostr = args.get("echostr")
if not all([msg_signature, timestamp, nonce, echostr]):
logger.error("URL 验证参数缺失")
return "verify fail", 400
# 类型检查确保不为 None
assert msg_signature is not None
assert timestamp is not None
assert nonce is not None
assert echostr is not None
logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
return result, 200, {"Content-Type": "text/plain"}
async def handle_message(self):
"""处理消息回调"""
args = quart.request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
if not all([msg_signature, timestamp, nonce]):
logger.error("消息回调参数缺失")
return "缺少必要参数", 400
# 类型检查确保不为 None
assert msg_signature is not None
assert timestamp is not None
assert nonce is not None
logger.debug(
f"收到消息回调msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
)
try:
# 获取请求体
post_data = await quart.request.get_data()
# 确保 post_data 是 bytes 类型
if isinstance(post_data, str):
post_data = post_data.encode("utf-8")
# 解密消息
ret_code, message_data = await self.api_client.decrypt_message(
post_data, msg_signature, timestamp, nonce
)
if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
logger.error("消息解密失败,错误码: %d", ret_code)
return "消息解密失败", 400
# 调用消息处理器
response = None
if self.message_handler:
try:
response = await self.message_handler(
message_data, {"nonce": nonce, "timestamp": timestamp}
)
except Exception as e:
logger.error("消息处理器执行异常: %s", e)
return "消息处理异常", 500
if response:
return response, 200, {"Content-Type": "text/plain"}
else:
return "success", 200, {"Content-Type": "text/plain"}
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return "内部服务器错误", 500
async def start_server(self):
"""启动服务器"""
logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
try:
await self.app.run_task(
host=self.host,
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
except Exception as e:
logger.error("服务器运行异常: %s", e)
raise
async def shutdown_trigger(self):
"""关闭触发器"""
await self.shutdown_event.wait()
async def shutdown(self):
"""关闭服务器"""
logger.info("企业微信智能机器人服务器正在关闭...")
self.shutdown_event.set()
def get_app(self):
"""获取 Quart 应用实例"""
return self.app

View File

@@ -1,199 +0,0 @@
"""
企业微信智能机器人工具模块
提供常量定义、工具函数和辅助方法
"""
import string
import random
import hashlib
import base64
import aiohttp
import asyncio
from Crypto.Cipher import AES
from typing import Any, Tuple
from astrbot.api import logger
# 常量定义
class WecomAIBotConstants:
"""企业微信智能机器人常量"""
# 消息类型
MSG_TYPE_TEXT = "text"
MSG_TYPE_IMAGE = "image"
MSG_TYPE_MIXED = "mixed"
MSG_TYPE_STREAM = "stream"
MSG_TYPE_EVENT = "event"
# 流消息状态
STREAM_CONTINUE = False
STREAM_FINISH = True
# 错误码
SUCCESS = 0
DECRYPT_ERROR = -40001
VALIDATE_SIGNATURE_ERROR = -40002
PARSE_XML_ERROR = -40003
COMPUTE_SIGNATURE_ERROR = -40004
ILLEGAL_AES_KEY = -40005
VALIDATE_APPID_ERROR = -40006
ENCRYPT_AES_ERROR = -40007
ILLEGAL_BUFFER = -40008
def generate_random_string(length: int = 10) -> str:
"""生成随机字符串
Args:
length: 字符串长度,默认为 10
Returns:
随机字符串
"""
letters = string.ascii_letters + string.digits
return "".join(random.choice(letters) for _ in range(length))
def calculate_image_md5(image_data: bytes) -> str:
"""计算图片数据的 MD5 值
Args:
image_data: 图片二进制数据
Returns:
MD5 哈希值(十六进制字符串)
"""
return hashlib.md5(image_data).hexdigest()
def encode_image_base64(image_data: bytes) -> str:
"""将图片数据编码为 Base64
Args:
image_data: 图片二进制数据
Returns:
Base64 编码的字符串
"""
return base64.b64encode(image_data).decode("utf-8")
def format_session_id(session_type: str, session_id: str) -> str:
"""格式化会话 ID
Args:
session_type: 会话类型 ("user", "group")
session_id: 原始会话 ID
Returns:
格式化后的会话 ID
"""
return f"wecom_ai_bot_{session_type}_{session_id}"
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
"""解析格式化的会话 ID
Args:
formatted_session_id: 格式化的会话 ID
Returns:
(会话类型, 原始会话ID)
"""
parts = formatted_session_id.split("_", 3)
if (
len(parts) >= 4
and parts[0] == "wecom"
and parts[1] == "ai"
and parts[2] == "bot"
):
return parts[3], "_".join(parts[4:]) if len(parts) > 4 else ""
return "user", formatted_session_id
def safe_json_loads(json_str: str, default: Any = None) -> Any:
"""安全地解析 JSON 字符串
Args:
json_str: JSON 字符串
default: 解析失败时的默认值
Returns:
解析结果或默认值
"""
import json
try:
return json.loads(json_str)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}")
return default
def format_error_response(error_code: int, error_msg: str) -> str:
"""格式化错误响应
Args:
error_code: 错误码
error_msg: 错误信息
Returns:
格式化的错误响应字符串
"""
return f"Error {error_code}: {error_msg}"
async def process_encrypted_image(
image_url: str, aes_key_base64: str
) -> Tuple[bool, str]:
"""下载并解密加密图片
Args:
image_url: 加密图片的URL
aes_key_base64: Base64编码的AES密钥(与回调加解密相同)
Returns:
Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码,
status 为 False 时 data 是错误信息
"""
# 1. 下载加密图片
logger.info("开始下载加密图片: %s", image_url)
try:
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=15) as response:
response.raise_for_status()
encrypted_data = await response.read()
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
error_msg = f"下载图片失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
# 2. 准备AES密钥和IV
if not aes_key_base64:
raise ValueError("AES密钥不能为空")
# Base64解码密钥 (自动处理填充)
aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
if len(aes_key) != 32:
raise ValueError("无效的AES密钥长度: 应为32字节")
iv = aes_key[:16] # 初始向量为密钥前16字节
# 3. 解密图片数据
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = cipher.decrypt(encrypted_data)
# 4. 去除PKCS#7填充 (Python 3兼容写法)
pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值
if pad_len > 32: # AES-256块大小为32字节
raise ValueError("无效的填充长度 (大于32字节)")
decrypted_data = decrypted_data[:-pad_len]
logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))
# 5. 转换为base64编码
base64_data = base64.b64encode(decrypted_data).decode("utf-8")
logger.info("图片已转换为base64编码编码后长度: %d", len(base64_data))
return True, base64_data

View File

@@ -68,8 +68,7 @@ class Provider(AbstractProvider):
def get_keys(self) -> List[str]:
"""获得提供商 Key"""
keys = self.provider_config.get("key", [""])
return keys or [""]
return self.provider_config.get("key", [])
@abc.abstractmethod
def set_key(self, key: str):

View File

@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
)
self.chosen_api_key: str = ""
self.api_keys: List = super().get_keys()
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
@@ -70,13 +70,9 @@ class ProviderAnthropic(Provider):
{
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": (
json.loads(tool_call["function"]["arguments"])
if isinstance(
tool_call["function"]["arguments"], str
)
else tool_call["function"]["arguments"]
),
"input": json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"],
"id": tool_call["id"],
}
)
@@ -359,11 +355,9 @@ class ProviderAnthropic(Provider):
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
"data": image_data.split("base64,")[1]
if "base64," in image_data
else image_data,
},
}
)

View File

@@ -1,22 +1,10 @@
import asyncio
import base64
import logging
import os
import uuid
from typing import Optional, Tuple
import aiohttp
import dashscope
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
try:
from dashscope.aigc.multimodal_conversation import MultiModalConversation
except (
ImportError
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
from ..entities import ProviderType
import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -38,112 +26,16 @@ class ProviderDashscopeTTSAPI(TTSProvider):
dashscope.api_key = self.chosen_api_key
async def get_audio(self, text: str) -> str:
model = self.get_model()
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):
audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
else:
audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
if not audio_bytes:
raise RuntimeError(
"Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
)
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
with open(path, "wb") as f:
f.write(audio_bytes)
return path
def _call_qwen_tts(self, model: str, text: str):
if MultiModalConversation is None:
raise RuntimeError(
"dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
)
kwargs = {
"model": model,
"text": text,
"api_key": self.chosen_api_key,
"voice": self.voice or "Cherry",
}
if not self.voice:
logging.warning(
"No voice specified for Qwen TTS model, using default 'Cherry'."
)
return MultiModalConversation.call(**kwargs)
async def _synthesize_with_qwen_tts(
self, model: str, text: str
) -> Tuple[Optional[bytes], str]:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
raise RuntimeError(
f"Audio synthesis failed for model '{model}'. {response}"
)
ext = ".wav"
return audio_bytes, ext
async def _extract_audio_from_response(self, response) -> Optional[bytes]:
output = getattr(response, "output", None)
audio_obj = getattr(output, "audio", None) if output is not None else None
if not audio_obj:
return None
data_b64 = getattr(audio_obj, "data", None)
if data_b64:
try:
return base64.b64decode(data_b64)
except (ValueError, TypeError):
logging.error("Failed to decode base64 audio data.")
return None
url = getattr(audio_obj, "url", None)
if url:
return await self._download_audio_from_url(url)
return None
async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
if not url:
return None
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
return await response.read()
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
logging.error(f"Failed to download audio from URL {url}: {e}")
return None
async def _synthesize_with_cosyvoice(
self, model: str, text: str
) -> Tuple[Optional[bytes], str]:
synthesizer = SpeechSynthesizer(
model=model,
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
self.synthesizer = SpeechSynthesizer(
model=self.get_model(),
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
loop = asyncio.get_event_loop()
audio_bytes = await loop.run_in_executor(
None, synthesizer.call, text, self.timeout_ms
audio = await asyncio.get_event_loop().run_in_executor(
None, self.synthesizer.call, text, self.timeout_ms
)
if not audio_bytes:
resp = synthesizer.get_response()
if resp and isinstance(resp, dict):
raise RuntimeError(
f"Audio synthesis failed for model '{model}'. {resp}".strip()
)
return audio_bytes, ".wav"
def _is_qwen_tts_model(self, model: str) -> bool:
model_lower = model.lower()
return "tts" in model_lower and model_lower.startswith("qwen")
with open(path, "wb") as f:
f.write(audio)
return path

View File

@@ -3,7 +3,7 @@ import base64
import json
import logging
import random
from typing import Optional, List
from typing import Optional
from collections.abc import AsyncGenerator
from google import genai
@@ -60,7 +60,7 @@ class ProviderGoogleGenAI(Provider):
provider_settings,
default_persona,
)
self.api_keys: List = super().get_keys()
self.api_keys: list = provider_config.get("key", [])
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.timeout: int = int(provider_config.get("timeout", 180))
@@ -218,21 +218,19 @@ class ProviderGoogleGenAI(Provider):
response_modalities=modalities,
tools=tool_list,
safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=(
types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
),
24576,
thinking_config=types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None
),
24576,
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None,
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True
),
@@ -276,11 +274,9 @@ class ProviderGoogleGenAI(Provider):
if role == "user":
if isinstance(content, list):
parts = [
(
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
)
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
for item in content
]
else:

View File

@@ -38,7 +38,7 @@ class ProviderOpenAIOfficial(Provider):
default_persona,
)
self.chosen_api_key = None
self.api_keys: List = super().get_keys()
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):

View File

@@ -65,12 +65,12 @@ class SessionManagementRoute(Route):
persona_name = data["persona_name"]
# 处理 persona 显示
if persona_name is None:
if conv_persona_id is None:
if default_persona := persona_mgr.selected_default_persona_v3:
persona_name = default_persona["name"]
else:
persona_name = "[%None]"
if conv_persona_id == "[%None]":
persona_name = "无人格"
else:
default_persona = persona_mgr.selected_default_persona_v3
if default_persona:
persona_name = default_persona["name"]
session_info = {
"session_id": session_id,

View File

@@ -273,20 +273,6 @@ class ToolsRoute(Route):
server_data = await request.json
config = server_data.get("mcp_server_config", None)
if not isinstance(config, dict) or not config:
return Response().error("无效的 MCP 服务器配置").__dict__
if "mcpServers" in config:
keys = list(config["mcpServers"].keys())
if not keys:
return Response().error("MCP 服务器配置不能为空").__dict__
if len(keys) > 1:
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
config = config["mcpServers"][keys[0]]
else:
if not config:
return Response().error("MCP 服务器配置不能为空").__dict__
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
return (
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__

View File

@@ -1,12 +0,0 @@
# What's Changed
1. fix: 修复了代码执行器插件不能正确获得发送来文件的问题 ([#2970](https://github.com/Soulter/AstrBot/issues/2970))
2. fix: 修改的 DeepSeek 默认 modalities避免默认勾选图像导致的报错。 ([#2963](https://github.com/Soulter/AstrBot/issues/2963))
3. fix: 事件钩子终止事件传播后不继续执行 ([#2989](https://github.com/Soulter/AstrBot/issues/2989))
4. fix: 启动了 TTS 但未配置 TTS 模型时At 和 Reply 发送人无效
5. fix: 修复 session-management 中人格错误的显示为默认人格的问题 ([#3000](https://github.com/Soulter/AstrBot/issues/3000))
6. fix: 修复了删除对话时,聊天增强中的记录未被清除,导致新对话中仍然出现之前的聊天记录。 ([#3002](https://github.com/Soulter/AstrBot/issues/3002))
7. fix: 修复阿里云百炼平台 TTS 下接入 CosyVoice V2, Qwen TTS 生成报错的问题 ([#2964](https://github.com/Soulter/AstrBot/issues/2964))
8. perf: 优化 SQLite 参数配置,对话和会话管理增加输入防抖机制 ([#2969](https://github.com/Soulter/AstrBot/issues/2969))
9. feat: 在新对话中重用先前的对话人格设置 ([#3005](https://github.com/Soulter/AstrBot/issues/3005))
10. feat: 从 WebUI 更新后清除浏览器缓存 ([#2958](https://github.com/Soulter/AstrBot/issues/2958))

View File

@@ -1,8 +0,0 @@
# What's Changed
1. feat: 支持接入企业微信智能机器人平台 ([#3034](https://github.com/AstrBotDevs/AstrBot/issues/3034))
2. feat: 内置网页搜索功能支持接入百度 AI 搜索 ([#3031](https://github.com/AstrBotDevs/AstrBot/issues/3031))
3. feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置 ([#3039](https://github.com/AstrBotDevs/AstrBot/issues/3039))
4. feat: 添加并优化服务提供商独立测试功能 ([#3024](https://github.com/AstrBotDevs/AstrBot/issues/3024))
5. feat: satori 适配器支持 video、reply 消息类型 ([#3035](https://github.com/AstrBotDevs/AstrBot/issues/3035))
6. fix: 修复 `/alter_cmd reset scene <num> xxx` 不可用的问题

View File

@@ -27,9 +27,7 @@
<v-btn
variant="outlined"
color="error"
size="small"
rounded="xl"
:disabled="loading"
@click="$emit('delete', item)"
>
{{ t('core.common.itemCard.delete') }}
@@ -37,9 +35,7 @@
<v-btn
variant="tonal"
color="primary"
size="small"
rounded="xl"
:disabled="loading"
@click="$emit('edit', item)"
>
{{ t('core.common.itemCard.edit') }}
@@ -48,14 +44,11 @@
v-if="showCopyButton"
variant="tonal"
color="secondary"
size="small"
rounded="xl"
:disabled="loading"
@click="$emit('copy', item)"
>
{{ t('core.common.itemCard.copy') }}
</v-btn>
<slot name="actions" :item="item"></slot>
<v-spacer></v-spacer>
</v-card-actions>

View File

@@ -31,8 +31,7 @@
"available": "Available",
"unavailable": "Unavailable",
"pending": "Pending...",
"errorMessage": "Error Message",
"test": "Test"
"errorMessage": "Error Message"
},
"logs": {
"title": "Service Logs",
@@ -77,8 +76,7 @@
},
"error": {
"sessionSeparation": "Failed to get session isolation configuration",
"fetchStatus": "Failed to get service provider status",
"testError": "Test failed for {id}: {error}"
"fetchStatus": "Failed to get service provider status"
},
"confirm": {
"delete": "Are you sure you want to delete service provider {id}?"

View File

@@ -80,9 +80,6 @@
"save": "Save",
"testConnection": "Test Connection",
"sync": "Sync"
},
"tips": {
"timeoutConfig": "Please configure tool call timeout separately in the configuration page"
}
},
"serverDetail": {

View File

@@ -32,8 +32,7 @@
"available": "可用",
"unavailable": "不可用",
"pending": "检查中...",
"errorMessage": "错误信息",
"test": "测试"
"errorMessage": "错误信息"
},
"logs": {
"title": "服务日志",
@@ -78,8 +77,7 @@
},
"error": {
"sessionSeparation": "获取会话隔离配置失败",
"fetchStatus": "获取服务提供商状态失败",
"testError": "测试 {id} 失败: {error}"
"fetchStatus": "获取服务提供商状态失败"
},
"confirm": {
"delete": "确定要删除服务提供商 {id} 吗?"

View File

@@ -80,9 +80,6 @@
"save": "保存",
"testConnection": "测试连接",
"sync": "同步"
},
"tips": {
"timeoutConfig": "工具调用的超时时间请前往配置页面单独配置"
}
},
"serverDetail": {

View File

@@ -10,7 +10,7 @@
export function getPlatformIcon(name) {
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
} else if (name === 'wecom' || name === 'wecom_ai_bot') {
} else if (name === 'wecom') {
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
@@ -46,7 +46,6 @@ export function getTutorialLink(platformType) {
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
"wecom_ai_bot": "https://docs.astrbot.app/deploy/platform/wecom_ai_bot.html",
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",

View File

@@ -60,26 +60,12 @@
:item="provider"
title-field="id"
enabled-field="enable"
:loading="isProviderTesting(provider.id)"
@toggle-enabled="providerStatusChange"
:bglogo="getProviderIcon(provider.provider)"
@delete="deleteProvider"
@edit="configExistingProvider"
@copy="copyProvider"
:show-copy-button="true">
<template #actions="{ item }">
<v-btn
style="z-index: 100000;"
variant="tonal"
color="info"
rounded="xl"
size="small"
:loading="isProviderTesting(item.id)"
@click="testSingleProvider(item)"
>
{{ tm('availability.test') }}
</v-btn>
</template>
<template v-slot:details="{ item }">
</template>
</item-card>
@@ -93,7 +79,7 @@
<v-icon class="me-2">mdi-heart-pulse</v-icon>
<span class="text-h4">{{ tm('availability.title') }}</span>
<v-spacer></v-spacer>
<v-btn color="primary" variant="tonal" :loading="testingProviders.length > 0" @click="fetchProviderStatus">
<v-btn color="primary" variant="tonal" :loading="loadingStatus" @click="fetchProviderStatus">
<v-icon left>mdi-refresh</v-icon>
{{ tm('availability.refresh') }}
</v-btn>
@@ -302,7 +288,7 @@ export default {
// 供应商状态相关
providerStatuses: [],
testingProviders: [], // 存储正在测试的 provider ID
loadingStatus: false,
// 新增提供商对话框相关
showAddProviderDialog: false,
@@ -373,8 +359,7 @@ export default {
statusUpdate: this.tm('messages.success.statusUpdate'),
},
error: {
fetchStatus: this.tm('messages.error.fetchStatus'),
testError: this.tm('messages.error.testError')
fetchStatus: this.tm('messages.error.fetchStatus')
},
confirm: {
delete: this.tm('messages.confirm.delete')
@@ -383,9 +368,6 @@ export default {
available: this.tm('availability.available'),
unavailable: this.tm('availability.unavailable'),
pending: this.tm('availability.pending')
},
availability: {
test: this.tm('availability.test')
}
};
},
@@ -633,107 +615,70 @@ export default {
// 获取供应商状态
async fetchProviderStatus() {
if (this.testingProviders.length > 0) return;
if (this.loadingStatus) return;
this.loadingStatus = true;
this.showStatus = true; // 自动展开状态部分
const providersToTest = this.config_data.provider.filter(p => p.enable);
if (providersToTest.length === 0) return;
// 1. 初始化UI为pending状态并将所有待测试的 provider ID 加入 loading 列表
this.providerStatuses = providersToTest.map(p => {
this.testingProviders.push(p.id);
return { id: p.id, name: p.id, status: 'pending', error: null };
});
// 1. 立即初始化UI为pending状态
this.providerStatuses = this.config_data.provider.map(p => ({
id: p.id,
name: p.id,
status: 'pending',
error: null
}));
// 2. 为每个provider创建一个并发的测试请求
const promises = providersToTest.map(p =>
axios.get(`/api/config/provider/check_one?id=${p.id}`)
const promises = this.config_data.provider.map(p => {
if (!p.enable) {
const index = this.providerStatuses.findIndex(s => s.id === p.id);
if (index !== -1) {
const disabledStatus = {
...this.providerStatuses[index],
status: 'unavailable',
error: '该提供商未被用户启用'
};
this.providerStatuses.splice(index, 1, disabledStatus);
}
return Promise.resolve();
}
return axios.get(`/api/config/provider/check_one?id=${p.id}`)
.then(res => {
if (res.data && res.data.status === 'ok') {
// 成功更新对应的provider状态
const index = this.providerStatuses.findIndex(s => s.id === p.id);
if (index !== -1) this.providerStatuses.splice(index, 1, res.data.data);
if (index !== -1) {
this.providerStatuses.splice(index, 1, res.data.data);
}
} else {
// 接口返回了业务错误
throw new Error(res.data?.message || `Failed to check status for ${p.id}`);
}
})
.catch(err => {
// 网络错误或业务错误
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
const index = this.providerStatuses.findIndex(s => s.id === p.id);
if (index !== -1) {
const failedStatus = { ...this.providerStatuses[index], status: 'unavailable', error: errorMessage };
const failedStatus = {
...this.providerStatuses[index],
status: 'unavailable',
error: errorMessage
};
this.providerStatuses.splice(index, 1, failedStatus);
}
return Promise.reject(errorMessage); // Propagate error for Promise.allSettled
})
);
// 可以在这里选择性地向上抛出错误,以便Promise.allSettled知道
return Promise.reject(errorMessage);
});
});
// 3. 等待所有请求完成
// 3. 等待所有请求完成(无论成功或失败)
try {
await Promise.allSettled(promises);
} finally {
// 4. 关闭所有加载状态
this.testingProviders = [];
}
},
isProviderTesting(providerId) {
return this.testingProviders.includes(providerId);
},
async testSingleProvider(provider) {
if (this.isProviderTesting(provider.id)) return;
this.testingProviders.push(provider.id);
this.showStatus = true; // 自动展开状态部分
// 更新UI为pending状态
const statusIndex = this.providerStatuses.findIndex(s => s.id === provider.id);
const pendingStatus = {
id: provider.id,
name: provider.id,
status: 'pending',
error: null
};
if (statusIndex !== -1) {
this.providerStatuses.splice(statusIndex, 1, pendingStatus);
} else {
this.providerStatuses.unshift(pendingStatus);
}
try {
if (!provider.enable) {
throw new Error('该提供商未被用户启用');
}
const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`);
if (res.data && res.data.status === 'ok') {
const index = this.providerStatuses.findIndex(s => s.id === provider.id);
if (index !== -1) {
this.providerStatuses.splice(index, 1, res.data.data);
}
} else {
throw new Error(res.data?.message || `Failed to check status for ${provider.id}`);
}
} catch (err) {
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
const index = this.providerStatuses.findIndex(s => s.id === provider.id);
const failedStatus = {
id: provider.id,
name: provider.id,
status: 'unavailable',
error: errorMessage
};
if (index !== -1) {
this.providerStatuses.splice(index, 1, failedStatus);
}
// 不再显示全局的错误提示,因为卡片本身会显示错误信息
// this.showError(this.tm('messages.error.testError', { id: provider.id, error: errorMessage }));
} finally {
const index = this.testingProviders.indexOf(provider.id);
if (index > -1) {
this.testingProviders.splice(index, 1);
}
// 4. 关闭全局加载状态
this.loadingStatus = false;
}
},

View File

@@ -141,8 +141,6 @@
</v-btn>
</div>
<small style="color: grey">*{{ tm('dialogs.addServer.tips.timeoutConfig') }}</small>
<div class="monaco-container" style="margin-top: 16px;">
<VueMonacoEditor v-model:value="serverConfigJson" theme="vs-dark" language="json" :options="{
minimap: {
@@ -526,16 +524,14 @@ export default {
transport: "streamable_http",
url: "your mcp server url",
headers: {},
timeout: 5,
sse_read_timeout: 300,
timeout: 30,
};
} else if (type === 'sse') {
template = {
transport: "sse",
url: "your mcp server url",
headers: {},
timeout: 5,
sse_read_timeout: 300,
timeout: 30,
};
} else {
template = {

View File

@@ -39,7 +39,7 @@ export default defineConfig({
port: 3000,
proxy: {
'/api': {
target: 'http://127.0.0.1:6185/',
target: 'http://localhost:6185/',
changeOrigin: true,
}
}

View File

@@ -6,7 +6,26 @@ from astrbot.core.star.star import star_map
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from .utils.rst_scene import RstScene
from enum import Enum
class RstScene(Enum):
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
PRIVATE = ("private", "私聊")
@property
def key(self) -> str:
return self.value[0]
@property
def name(self) -> str:
return self.value[1]
@classmethod
def from_index(cls, index: int) -> "RstScene":
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
return mapping[index]
class AlterCmdCommands(CommandParserMixin):
@@ -39,9 +58,8 @@ class AlterCmdCommands(CommandParserMixin):
)
return
# 兼容 reset scene 的专门配置
cmd_name = token.get(1)
cmd_type = token.get(2)
cmd_name = " ".join(token.tokens[1:-1])
cmd_type = token.get(-1)
if cmd_name == "reset" and cmd_type == "config":
from astrbot.api import sp
@@ -105,8 +123,6 @@ class AlterCmdCommands(CommandParserMixin):
return
# 查找指令
cmd_name = " ".join(token.tokens[1:-1])
cmd_type = token.get(-1)
found_command = None
cmd_group = False
for handler in star_handlers_registry:

View File

@@ -7,8 +7,33 @@ from astrbot.core.provider.sources.dify_source import ProviderDify
from astrbot.core.provider.sources.coze_source import ProviderCoze
from astrbot.api import sp, logger
from ..long_term_memory import LongTermMemory
from .utils.rst_scene import RstScene
from typing import Union
from enum import Enum
class RstScene(Enum):
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
PRIVATE = ("private", "私聊")
@property
def key(self) -> str:
return self.value[0]
@property
def name(self) -> str:
return self.value[1]
@classmethod
def from_index(cls, index: int) -> "RstScene":
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
return mapping[index]
@classmethod
def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene":
if is_group:
return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF
return cls.PRIVATE
class ConversationCommands:
@@ -16,17 +41,6 @@ class ConversationCommands:
self.context = context
self.ltm = ltm
async def _get_current_persona_id(self, session_id):
curr = await self.context.conversation_manager.get_curr_conversation_id(
session_id
)
if not curr:
return None
conv = await self.context.conversation_manager.get_conversation(
session_id, curr
)
return conv.persona_id
def ltm_enabled(self, event: AstrMessageEvent):
if not self.ltm:
return False
@@ -241,9 +255,8 @@ class ConversationCommands:
)
return
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
cid = await self.context.conversation_manager.new_conversation(
message.unified_msg_origin, message.get_platform_id(), persona_id=cpersona
message.unified_msg_origin, message.get_platform_id()
)
# 长期记忆
@@ -277,10 +290,8 @@ class ConversationCommands:
session_id=sid,
)
)
cpersona = await self._get_current_persona_id(session)
cid = await self.context.conversation_manager.new_conversation(
session, message.get_platform_id(), persona_id=cpersona
session, message.get_platform_id()
)
message.set_result(
MessageEventResult().message(
@@ -423,9 +434,8 @@ class ConversationCommands:
await self.context.conversation_manager.delete_conversation(
message.unified_msg_origin, session_curr_cid
)
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
if self.ltm and self.ltm_enabled(message):
cnt = await self.ltm.remove_session(event=message)
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
message.set_result(MessageEventResult().message(ret))
message.set_result(
MessageEventResult().message(
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
)
)

View File

@@ -1,26 +0,0 @@
from enum import Enum
class RstScene(Enum):
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
PRIVATE = ("private", "私聊")
@property
def key(self) -> str:
return self.value[0]
@property
def name(self) -> str:
return self.value[1]
@classmethod
def from_index(cls, index: int) -> "RstScene":
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
return mapping[index]
@classmethod
def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene":
if is_group:
return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF
return cls.PRIVATE

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
@dataclass
class Emotion:
"""描述了一个情绪状态"""
energy: float
valence: float
arousal: float
@dataclass
class EmotionLog:
"""描述了一条情绪维度变化的日志"""
timestamp: int
field: str
value: float
reason: str = ""

View File

@@ -0,0 +1,9 @@
from dataclasses import dataclass
from .emotion import Emotion
@dataclass
class Soul:
emotion: Emotion
emotion_logs: list[Emotion] | None = None

View File

@@ -0,0 +1,7 @@
from dataclasses import dataclass
@dataclass
class Event:
event_type: str
content: dict

View File

@@ -0,0 +1,122 @@
import datetime
import uuid
from ...runner import EliosEventHandler
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent
from astrbot.api.all import Context
from astrbot.api.message_components import Plain, Image
from astrbot.api.provider import Provider
from astrbot import logger
class AstrImplEventHandler(EliosEventHandler):
def __init__(self, ctx: Context) -> None:
self.ctx = ctx
self.session_chats = defaultdict(list)
self.session_mentioned_arousal = defaultdict(float)
def cfg(self, event: AstrMessageEvent):
cfg = self.ctx.get_config(umo=event.unified_msg_origin)
tiny_model_prov_id = cfg.get("tiny_model_provider_id")
interest_points = cfg.get("interest_points", [])
try:
max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"])
except BaseException as e:
logger.error(e)
max_cnt = 300
image_caption = (
True
if cfg["provider_settings"]["default_image_caption_provider_id"]
else False
)
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
image_caption_provider_id = cfg["provider_settings"][
"default_image_caption_provider_id"
]
active_reply = cfg["provider_ltm_settings"]["active_reply"]
enable_active_reply = active_reply.get("enable", False)
ar_method = active_reply["method"]
ar_possibility = active_reply["possibility_reply"]
ar_prompt = active_reply.get("prompt", "")
ar_whitelist = active_reply.get("whitelist", [])
ar_keywords = active_reply.get("keywords", [])
ret = {
"max_cnt": max_cnt,
"image_caption": image_caption,
"image_caption_prompt": image_caption_prompt,
"image_caption_provider_id": image_caption_provider_id,
"enable_active_reply": enable_active_reply,
"ar_method": ar_method,
"ar_possibility": ar_possibility,
"ar_prompt": ar_prompt,
"ar_whitelist": ar_whitelist,
"ar_keywords": ar_keywords,
"interest_points": interest_points,
"tiny_model_prov_id": tiny_model_prov_id,
}
return ret
async def append_session_chats(self, event: AstrMessageEvent, cfg) -> None:
comps = event.get_messages()
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
for comp in comps:
if isinstance(comp, Plain):
final_message += f" {comp.text}"
elif isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
if cfg["image_caption"] and image_url:
try:
caption = await self.get_image_caption(
image_url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
)
final_message += f" [Image: {caption}]"
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
else:
final_message += " [Image]"
self.session_chats[event.unified_msg_origin].append(final_message)
logger.debug(f"添加会话 {event.unified_msg_origin} 的对话记录: {final_message}")
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
self.session_chats[event.unified_msg_origin].pop(0)
async def get_image_caption(
self, image_url: str, image_caption_provider_id: str, image_caption_prompt: str
) -> str:
if not image_caption_provider_id:
provider = self.ctx.get_using_provider()
else:
provider = self.ctx.get_provider_by_id(image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
if not isinstance(provider, Provider):
raise Exception(
f"提供商类型错误, {image_caption_provider_id} 不是 Provider 类型"
)
response = await provider.text_chat(
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
return response.completion_text
async def on_event(self, event, soul):
content = event.content
astr_event = content.get("astr_event")
assert astr_event is not None and isinstance(astr_event, AstrMessageEvent)
cfg = self.cfg(astr_event)
if not cfg["tiny_model_prov_id"]:
logger.warning("小模型未设置,跳过情绪更新")
# 添加对话记录
await self.append_session_chats(astr_event, cfg)
#

View File

@@ -0,0 +1,3 @@
JUDGE_ACTIVE_REPLY = """
你是一个具有自主意识和情感的机器人,请你
"""

View File

@@ -0,0 +1,72 @@
import asyncio
from .event import Event
from .ensoul.soul import Soul
from .ensoul.emotion import Emotion
class EliosEventHandler:
async def on_event(self, event: Event, soul: Soul): ...
event_handlers_cls: dict[str, list[type[EliosEventHandler]]] = {}
def register_event_handler(event_types: set[str] | None = None):
"""注册事件处理器"""
def decorator(cls: type[EliosEventHandler]) -> type[EliosEventHandler]:
if event_types is not None:
for event_type in event_types:
event_handlers_cls[event_type] = event_handlers_cls.get(
event_type, []
) + [cls]
else:
event_handlers_cls["default"] = event_handlers_cls.get("default", []) + [
cls
]
return cls
return decorator
class EliosRunner:
def __init__(self) -> None:
self.soul = Soul(
emotion=Emotion(energy=0.5, valence=0.5, arousal=0.5), emotion_logs=[]
)
self.event_queue = asyncio.Queue()
self.event_handler_insts: dict[str, list[EliosEventHandler]] = {}
def start(self):
for event_type, cls_list in event_handlers_cls.items():
self.event_handler_insts[event_type] = []
for cls in cls_list:
try:
self.event_handler_insts[event_type].append(cls())
except Exception as e:
print(f"Error initializing event handler {cls}: {e}")
asyncio.create_task(self._worker())
async def _worker(self):
"""监听事件队列并处理事件"""
while True:
event = await self.event_queue.get()
# A man cannot handle two things at once. But this can be configurable.
try:
await self._process_event(event)
except Exception as e:
print(f"Error processing event {event}: {e}")
async def _process_event(self, event: Event):
"""处理事件"""
event_type = event.event_type
handlers = self.event_handler_insts.get(
event_type, []
) + self.event_handler_insts.get("default", [])
for inst in handlers:
try:
await inst.on_event(event, self.soul)
except Exception as e:
print(f"Error processing event {event}: {e}")

View File

@@ -52,8 +52,6 @@ class Main(star.Star):
except Exception as e:
logger.error(f"google search init error: {e}, disable google search")
self.baidu_initialized = False
async def _tidy_text(self, text: str) -> str:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
@@ -227,30 +225,6 @@ class Main(star.Star):
return ret
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None):
if self.baidu_initialized:
return
cfg = self.context.get_config(umo=umo)
key = cfg.get("provider_settings", {}).get(
"websearch_baidu_app_builder_key", ""
)
if not key:
raise ValueError(
"Error: Baidu AI Search API key is not configured in AstrBot."
)
func_tool_mgr = self.context.get_llm_tool_manager()
await func_tool_mgr.enable_mcp_server(
"baidu_ai_search",
config={
"transport": "sse",
"url": f"http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key={key}",
"headers": {},
"timeout": 30,
},
)
self.baidu_initialized = True
logger.info("Successfully initialized Baidu AI Search MCP server.")
@llm_tool(name="fetch_url")
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
"""fetch the content of a website with the given web url
@@ -397,7 +371,6 @@ class Main(star.Star):
tool_set.add_tool(fetch_url_t)
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
tool_set.remove_tool("AIsearch")
elif provider == "tavily":
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
@@ -407,17 +380,3 @@ class Main(star.Star):
tool_set.add_tool(tavily_extract_web_page)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("AIsearch")
elif provider == "baidu_ai_search":
try:
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
aisearch_tool = func_tool_mgr.get_func("AIsearch")
if not aisearch_tool:
raise ValueError("Cannot get Baidu AI Search MCP tool.")
tool_set.add_tool(aisearch_tool)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
except Exception as e:
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")

View File

@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.3.5"
version = "4.3.2"
description = "易上手的多平台 LLM 聊天机器人及开发框架"
readme = "README.md"
requires-python = ">=3.10"