From 3e07fbf3dc14e85edd679d7f9d8dc79a8e4ec9d6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 May 2025 11:32:35 -0400 Subject: [PATCH 01/19] =?UTF-8?q?feat:=20=E5=BE=AE=E4=BF=A1=E5=AE=A2?= =?UTF-8?q?=E6=9C=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 8 +- .../platform/sources/wecom/wecom_adapter.py | 91 ++++++- .../platform/sources/wecom/wecom_event.py | 131 +++++---- .../core/platform/sources/wecom/wecom_kf.py | 255 ++++++++++++++++++ .../sources/wecom/wecom_kf_message.py | 136 ++++++++++ 5 files changed, 568 insertions(+), 53 deletions(-) create mode 100644 astrbot/core/platform/sources/wecom/wecom_kf.py create mode 100644 astrbot/core/platform/sources/wecom/wecom_kf_message.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d4f92438..33e4bc3d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -159,6 +159,7 @@ CONFIG_METADATA_2 = { "secret": "", "token": "", "encoding_aes_key": "", + "kf_name": "", "api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/", "callback_server_host": "0.0.0.0", "port": 6195, @@ -193,6 +194,11 @@ CONFIG_METADATA_2 = { }, }, "items": { + "kf_name": { + "description": "微信客服账号名", + "type": "string", + "hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取" + }, "telegram_token": { "description": "Bot Token", "type": "string", @@ -237,7 +243,7 @@ CONFIG_METADATA_2 = { "secret": { "description": "secret", "type": "string", - "hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。", + "hint": "必填项。", }, "enable_group_c2c": { "description": "启用消息列表单聊", diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index c6b7c096..78465880 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -2,6 +2,7 @@ import sys import uuid import asyncio import quart +import aiohttp from astrbot.api.platform import ( Platform, @@ -20,10 +21,14 @@ from requests import Response from wechatpy.enterprise.crypto import WeChatCrypto from wechatpy.enterprise import WeChatClient from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage +from wechatpy.messages import BaseMessage from wechatpy.exceptions import InvalidSignatureException from wechatpy.enterprise import parse_message from .wecom_event import WecomPlatformEvent +from .wecom_kf import WeChatKF +from .wecom_kf_message import WeChatKFMessage + if sys.version_info >= (3, 12): from typing import override else: @@ -131,9 +136,39 @@ class WecomPlatformAdapter(Platform): self.config["corpid"].strip(), self.config["secret"].strip(), ) + # inject + self.wechat_kf_api = WeChatKF(client=self.client) + self.wechat_kf_message_api = WeChatKFMessage(self.client) + self.client.kf = self.wechat_kf_api + self.client.kf_message = self.wechat_kf_message_api + self.client.API_BASE_URL = self.api_base_url - async def callback(msg): + # 微信客服 + self.kf_name = self.config.get("kf_name", None) + + async def callback(msg: BaseMessage): + if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": + + def get_latest_msg_item() -> dict | None: + token = msg._data["Token"] + kfid = msg._data["OpenKfId"] + has_more = 1 + ret = {} + while has_more: + ret = self.wechat_kf_api.sync_msg(token, kfid) + has_more = ret["has_more"] + msg_list = ret.get("msg_list", []) + if msg_list: + return msg_list[-1] + return None + + msg_new = await asyncio.get_event_loop().run_in_executor( + None, get_latest_msg_item + ) + if msg_new: + await self.convert_wechat_kf_message(msg_new) + return await self.convert_message(msg) self.server.callback = callback @@ -153,9 +188,39 @@ class WecomPlatformAdapter(Platform): @override async def run(self): + loop = asyncio.get_event_loop() + if self.kf_name: + try: + acc_list = ( + await loop.run_in_executor( + None, self.wechat_kf_api.get_account_list + ) + ).get("account_list", []) + logger.debug(f"获取到微信客服列表: {str(acc_list)}") + for acc in acc_list: + name = acc.get("name", None) + if name != self.kf_name: + continue + open_kfid = acc.get("open_kfid", None) + if not open_kfid: + logger.error("获取微信客服失败,open_kfid 为空。") + logger.debug(f"Found open_kfid: {str(open_kfid)}") + kf_url = ( + await loop.run_in_executor( + None, + self.wechat_kf_api.add_contact_way, + open_kfid, + "astrbot_placeholder", + ) + ).get("url", "") + logger.info( + f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}" + ) + except Exception as e: + logger.error(e) await self.server.start_polling() - async def convert_message(self, msg): + async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: abm = AstrBotMessage() if msg.type == "text": assert isinstance(msg, TextMessage) @@ -218,10 +283,32 @@ class WecomPlatformAdapter(Platform): abm.timestamp = msg.time abm.session_id = abm.sender.user_id abm.raw_message = msg + else: + logger.warning(f"暂未实现的事件: {msg.type}") + return logger.info(f"abm: {abm}") await self.handle_msg(abm) + async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: + msgtype = msg.get("msgtype", None) + abm = AstrBotMessage() + abm.raw_message = msg + abm.raw_message["_wechat_kf_flag"] = None # 方便处理 + abm.self_id = msg["open_kfid"] + if msgtype == "text": + external_userid = msg.get("external_userid", None) + text = msg.get("text", {}).get("content", "").strip() + abm.message = [Plain(text=text)] + abm.sender = MessageMember(external_userid, external_userid) + abm.message_str = text + abm.session_id = external_userid + abm.type = MessageType.FRIEND_MESSAGE + else: + logger.warning(f"未实现的微信客服消息事件: {msg}") + return + await self.handle_msg(abm) + async def handle_msg(self, message: AstrBotMessage): message_event = WecomPlatformEvent( message_str=message.message_str, diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 05fc33da..791e4ff7 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -4,6 +4,7 @@ from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image, Record from wechatpy.enterprise import WeChatClient +from .wecom_kf_message import WeChatKFMessage from astrbot.api import logger @@ -52,19 +53,29 @@ class WecomPlatformEvent(AstrMessageEvent): if start + 2048 >= len(plain): result.append(plain[start:]) break - + # 向前搜索分割标点符号 end = min(start + 2048, len(plain)) cut_position = end for i in range(end, start, -1): - if i < len(plain) and plain[i-1] in ["。", "!", "?", ".", "!", "?", "\n", ";", ";"]: + if i < len(plain) and plain[i - 1] in [ + "。", + "!", + "?", + ".", + "!", + "?", + "\n", + ";", + ";", + ]: cut_position = i break - + # 没找到合适的位置分割, 直接切分 if cut_position == end and end < len(plain): cut_position = end - + result.append(plain[start:cut_position]) start = cut_position @@ -73,57 +84,77 @@ class WecomPlatformEvent(AstrMessageEvent): async def send(self, message: MessageChain): message_obj = self.message_obj - for comp in message.chain: - if isinstance(comp, Plain): - # Split long text messages if needed - plain_chunks = await self.split_plain(comp.text) - for chunk in plain_chunks: - self.client.message.send_text( - message_obj.self_id, message_obj.session_id, chunk - ) - await asyncio.sleep(0.5) # Avoid sending too fast - elif isinstance(comp, Image): - img_path = await comp.convert_to_file_path() - - with open(img_path, "rb") as f: - try: - response = self.client.media.upload("image", f) - except Exception as e: - logger.error(f"企业微信上传图片失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传图片失败: {e}") + is_wechat_kf = hasattr(self.message_obj.raw_message, "_wechat_kf_flag") + if is_wechat_kf: + kf_message_api = getattr(self.client, "kf_message", None) + if not kf_message_api: + logger.warning("未找到微信客服发送消息方法。") + return + assert isinstance(kf_message_api, WeChatKFMessage) + for comp in message.chain: + if isinstance(comp, Plain): + # Split long text messages if needed + plain_chunks = await self.split_plain(comp.text) + for chunk in plain_chunks: + # self.client.message.send_text( + # message_obj.self_id, message_obj.session_id, chunk + # ) + # kf_message_api.send_text() + await asyncio.sleep(0.5) # Avoid sending too fast + else: + logger.warning("没有实现的回复消息类型。") + else: + for comp in message.chain: + if isinstance(comp, Plain): + # Split long text messages if needed + plain_chunks = await self.split_plain(comp.text) + for chunk in plain_chunks: + self.client.message.send_text( + message_obj.self_id, message_obj.session_id, chunk ) - return - logger.info(f"企业微信上传图片返回: {response}") - self.client.message.send_image( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - # 转成amr - record_path_amr = f"data/temp/{uuid.uuid4()}.amr" - pydub.AudioSegment.from_wav(record_path).export( - record_path_amr, format="amr" - ) + await asyncio.sleep(0.5) # Avoid sending too fast + elif isinstance(comp, Image): + img_path = await comp.convert_to_file_path() - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"企业微信上传语音失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传语音失败: {e}") + with open(img_path, "rb") as f: + try: + response = self.client.media.upload("image", f) + except Exception as e: + logger.error(f"企业微信上传图片失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传图片失败: {e}") + ) + return + logger.info(f"企业微信上传图片返回: {response}") + self.client.message.send_image( + message_obj.self_id, + message_obj.session_id, + response["media_id"], ) - return - logger.info(f"企业微信上传语音返回: {response}") - self.client.message.send_voice( - message_obj.self_id, - message_obj.session_id, - response["media_id"], + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + # 转成amr + record_path_amr = f"data/temp/{uuid.uuid4()}.amr" + pydub.AudioSegment.from_wav(record_path).export( + record_path_amr, format="amr" ) + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"企业微信上传语音失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传语音失败: {e}") + ) + return + logger.info(f"企业微信上传语音返回: {response}") + self.client.message.send_voice( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) + await super().send(message) async def send_streaming(self, generator, use_fallback: bool = False): diff --git a/astrbot/core/platform/sources/wecom/wecom_kf.py b/astrbot/core/platform/sources/wecom/wecom_kf.py new file mode 100644 index 00000000..8ea5a6d5 --- /dev/null +++ b/astrbot/core/platform/sources/wecom/wecom_kf.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- + + +from wechatpy.client.api.base import BaseWeChatAPI + + +class WeChatKF(BaseWeChatAPI): + """ + 微信客服接口 + + https://work.weixin.qq.com/api/doc/90000/90135/94670 + """ + + def sync_msg(self, token, open_kfid, cursor="", limit=1000): + """ + 微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收) + 、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。 + 支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。 + + + :param token: 回调事件返回的token字段,10分钟内有效;可不填,如果不填接口有严格的频率限制。不多于128字节 + :param open_kfid: 客服帐号ID + :param cursor: 上一次调用时返回的next_cursor,第一次拉取可以不填。不多于64字节 + :param limit: 期望请求的数据量,默认值和最大值都为1000。 + 注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。 + :return: 接口调用结果 + """ + data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid} + return self._post("kf/sync_msg", data=data) + + def get_service_state(self, open_kfid, external_userid): + """ + 获取会话状态 + + ID 状态 说明 + 0 未处理 新会话接入。可选择:1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待 + 1 由智能助手接待 可使用API回复消息。可选择转入待接入池或者指定接待人员处理。 + 2 待接入池排队中 在待接入池中排队等待接待人员接入。可选择转为指定人员接待 + 3 由人工接待 人工接待中。可选择结束会话 + 4 已结束 会话已经结束。不允许变更会话状态,等待用户重新发起咨询 + + :param open_kfid: 客服帐号ID + :param external_userid: 微信客户的external_userid + :return: 接口调用结果 + """ + data = { + "open_kfid": open_kfid, + "external_userid": external_userid, + } + return self._post("kf/service_state/get", data=data) + + def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""): + """ + 变更会话状态 + + :param open_kfid: 客服帐号ID + :param external_userid: 微信客户的external_userid + :param service_state: 当前的会话状态,状态定义参考概述中的表格 + :return: 接口调用结果 + """ + data = { + "open_kfid": open_kfid, + "external_userid": external_userid, + "service_state": service_state, + } + if servicer_userid: + data["servicer_userid"] = servicer_userid + return self._post("kf/service_state/trans", data=data) + + def get_servicer_list(self, open_kfid): + """ + 获取接待人员列表 + + :param open_kfid: 客服帐号ID + :return: 接口调用结果 + """ + data = { + "open_kfid": open_kfid, + } + return self._get("kf/servicer/list", params=data) + + def add_servicer(self, open_kfid, userid_list): + """ + 添加接待人员 + 添加指定客服帐号的接待人员。 + + :param open_kfid: 客服帐号ID + :param userid_list: 接待人员userid列表 + :return: 接口调用结果 + """ + if not isinstance(userid_list, list): + userid_list = [userid_list] + + data = { + "open_kfid": open_kfid, + "userid_list": userid_list, + } + return self._post("kf/servicer/add", data=data) + + def del_servicer(self, open_kfid, userid_list): + """ + 删除接待人员 + 从客服帐号删除接待人员 + + :param open_kfid: 客服帐号ID + :param userid_list: 接待人员userid列表 + :return: 接口调用结果 + """ + if not isinstance(userid_list, list): + userid_list = [userid_list] + + data = { + "open_kfid": open_kfid, + "userid_list": userid_list, + } + return self._post("kf/servicer/del", data=data) + + def batchget_customer(self, external_userid_list): + """ + 客户基本信息获取 + + :param external_userid_list: external_userid列表 + :return: 接口调用结果 + """ + if not isinstance(external_userid_list, list): + external_userid_list = [external_userid_list] + + data = { + "external_userid_list": external_userid_list, + } + return self._post("kf/customer/batchget", data=data) + + def get_account_list(self): + """ + 获取客服帐号列表 + + :return: 接口调用结果 + """ + return self._get("kf/account/list") + + def add_contact_way(self, open_kfid, scene): + """ + 获取客服帐号链接 + + :param open_kfid: 客服帐号ID + :param scene: 场景值,字符串类型,由开发者自定义。不多于32字节;字符串取值范围(正则表达式):[0-9a-zA-Z_-]* + :return: 接口调用结果 + """ + data = {"open_kfid": open_kfid, "scene": scene} + return self._post("kf/add_contact_way", data=data) + + def get_upgrade_service_config(self): + """ + 获取配置的专员与客户群 + + :return: 接口调用结果 + """ + return self._get("kf/customer/get_upgrade_service_config") + + def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None): + """ + 为客户升级为专员或客户群服务 + + :param open_kfid: 客服帐号ID + :param external_userid: 微信客户的external_userid + :param service_type: 表示是升级到专员服务还是客户群服务。1:专员服务。2:客户群服务 + :param member: 推荐的服务专员,type等于1时有效 + :param groupchat: 推荐的客户群,type等于2时有效 + :return: 接口调用结果 + """ + + data = { + "open_kfid": open_kfid, + "external_userid": external_userid, + "type": service_type, + } + if service_type == 1: + data["member"] = member + else: + data["groupchat"] = groupchat + return self._post("kf/customer/upgrade_service", data=data) + + def cancel_upgrade_service(self, open_kfid, external_userid): + """ + 为客户取消推荐 + + :param open_kfid: 客服帐号ID + :param external_userid: 微信客户的external_userid + :return: 接口调用结果 + """ + + data = {"open_kfid": open_kfid, "external_userid": external_userid} + return self._post("kf/customer/cancel_upgrade_service", data=data) + + def send_msg_on_event(self, code, msgtype, msg_content, msgid=None): + """ + 当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。 + 支持发送消息类型:文本、菜单消息。 + + :param code: 事件响应消息对应的code。通过事件回调下发,仅可使用一次。 + :param msgtype: 消息类型。对不同的msgtype,有相应的结构描述,详见消息类型 + :param msg_content: 目前支持文本与菜单消息,具体查看文档 + :param msgid: 消息ID。如果请求参数指定了msgid,则原样返回,否则系统自动生成并返回。不多于32字节; + 字符串取值范围(正则表达式):[0-9a-zA-Z_-]* + :return: 接口调用结果 + """ + + data = {"code": code, "msgtype": msgtype} + if msgid: + data["msgid"] = msgid + data.update(msg_content) + return self._post("kf/send_msg_on_event", data=data) + + def get_corp_statistic(self, start_time, end_time, open_kfid=None): + """ + 获取「客户数据统计」企业汇总数据 + + :param start_time: 开始时间 + :param end_time: 结束时间 + :param open_kfid: 客服帐号ID + :return: 接口调用结果 + """ + data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time} + return self._post("kf/get_corp_statistic", data=data) + + def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None): + """ + 获取「客户数据统计」接待人员明细数据 + + :param start_time: 开始时间 + :param end_time: 结束时间 + :param open_kfid: 客服帐号ID + :param servicer_userid: 接待人员 + :return: 接口调用结果 + """ + data = { + "open_kfid": open_kfid, + "servicer_userid": servicer_userid, + "start_time": start_time, + "end_time": end_time, + } + return self._post("kf/get_servicer_statistic", data=data) + + def account_update(self, open_kfid, name, media_id): + """ + 修改客服账号 + + :param open_kfid: 客服帐号ID + :param name: 客服名称 + :param media_id: 客服头像临时素材 + + :return: 接口调用结果 + """ + data = {"open_kfid": open_kfid, "name": name, "media_id": media_id} + return self._post("kf/account/update", data=data) diff --git a/astrbot/core/platform/sources/wecom/wecom_kf_message.py b/astrbot/core/platform/sources/wecom/wecom_kf_message.py new file mode 100644 index 00000000..fd9b943a --- /dev/null +++ b/astrbot/core/platform/sources/wecom/wecom_kf_message.py @@ -0,0 +1,136 @@ + +from optionaldict import optionaldict + +from wechatpy.client.api.base import BaseWeChatAPI + +class WeChatKFMessage(BaseWeChatAPI): + """ + 发送微信客服消息 + + https://work.weixin.qq.com/api/doc/90000/90135/94677 + + 支持: + * 文本消息 + * 图片消息 + * 语音消息 + * 视频消息 + * 文件消息 + * 图文链接 + * 小程序 + * 菜单消息 + * 地理位置 + """ + + def send(self, user_id, open_kfid, msgid="", msg=None): + """ + 当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。 + 注意仅当微信客户在主动发送消息给客服后的48小时内,企业可发送消息给客户,最多可发送5条消息;若用户继续发送消息,企业可再次下发消息。 + 支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。 + + :param user_id: 指定接收消息的客户UserID + :param open_kfid: 指定发送消息的客服帐号ID + :param msgid: 指定消息ID + :param tag_ids: 标签ID列表。 + :param msg: 发送消息的 dict 对象 + :type msg: dict | None + :return: 接口调用结果 + """ + msg = msg or {} + data = { + "touser": user_id, + "open_kfid": open_kfid, + } + if msgid: + data["msgid"] = msgid + data.update(msg) + return self._post("kf/send_msg", data=data) + + def send_text(self, user_id, open_kfid, content, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={"msgtype": "text", "text": {"content": content}}, + ) + + def send_image(self, user_id, open_kfid, media_id, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={"msgtype": "image", "image": {"media_id": media_id}}, + ) + + def send_voice(self, user_id, open_kfid, media_id, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={"msgtype": "voice", "voice": {"media_id": media_id}}, + ) + + def send_video(self, user_id, open_kfid, media_id, msgid=""): + video_data = optionaldict() + video_data["media_id"] = media_id + + return self.send( + user_id, + open_kfid, + msgid, + msg={"msgtype": "video", "video": dict(video_data)}, + ) + + def send_file(self, user_id, open_kfid, media_id, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={"msgtype": "file", "file": {"media_id": media_id}}, + ) + + def send_articles_link(self, user_id, open_kfid, article, msgid=""): + articles_data = { + "title": article["title"], + "desc": article["desc"], + "url": article["url"], + "thumb_media_id": article["thumb_media_id"], + } + return self.send( + user_id, + open_kfid, + msgid, + msg={"msgtype": "news", "link": {"link": articles_data}}, + ) + + def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={ + "msgtype": "msgmenu", + "msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content}, + }, + ) + + def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={ + "msgtype": "location", + "msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude}, + }, + ) + + def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""): + return self.send( + user_id, + open_kfid, + msgid, + msg={ + "msgtype": "miniprogram", + "msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath}, + }, + ) From c36054ca1b1dbb7bc3e68d0bd5ae1b7335872a4e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 May 2025 11:53:50 -0400 Subject: [PATCH 02/19] =?UTF-8?q?=E2=9C=A8=20feat:=20=E5=BE=AE=E4=BF=A1?= =?UTF-8?q?=E5=AE=A2=E6=9C=8D=E6=94=AF=E6=8C=81=E6=96=87=E6=9C=AC=E6=B6=88?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/platform/sources/wecom/wecom_adapter.py | 15 ++++++++------- .../core/platform/sources/wecom/wecom_event.py | 10 +++++----- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 78465880..234ba4e8 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -136,16 +136,17 @@ class WecomPlatformAdapter(Platform): self.config["corpid"].strip(), self.config["secret"].strip(), ) - # inject - self.wechat_kf_api = WeChatKF(client=self.client) - self.wechat_kf_message_api = WeChatKFMessage(self.client) - self.client.kf = self.wechat_kf_api - self.client.kf_message = self.wechat_kf_message_api - - self.client.API_BASE_URL = self.api_base_url # 微信客服 self.kf_name = self.config.get("kf_name", None) + if self.kf_name: + # inject + self.wechat_kf_api = WeChatKF(client=self.client) + self.wechat_kf_message_api = WeChatKFMessage(self.client) + self.client.kf = self.wechat_kf_api + self.client.kf_message = self.wechat_kf_message_api + + self.client.API_BASE_URL = self.api_base_url async def callback(msg: BaseMessage): if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 791e4ff7..01cec423 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -84,26 +84,26 @@ class WecomPlatformEvent(AstrMessageEvent): async def send(self, message: MessageChain): message_obj = self.message_obj - is_wechat_kf = hasattr(self.message_obj.raw_message, "_wechat_kf_flag") + is_wechat_kf = hasattr(self.client, "kf_message") if is_wechat_kf: + # 微信客服 kf_message_api = getattr(self.client, "kf_message", None) if not kf_message_api: logger.warning("未找到微信客服发送消息方法。") return assert isinstance(kf_message_api, WeChatKFMessage) + user_id = self.get_sender_id() for comp in message.chain: if isinstance(comp, Plain): # Split long text messages if needed plain_chunks = await self.split_plain(comp.text) for chunk in plain_chunks: - # self.client.message.send_text( - # message_obj.self_id, message_obj.session_id, chunk - # ) - # kf_message_api.send_text() + kf_message_api.send_text(user_id, self.get_self_id(), chunk) await asyncio.sleep(0.5) # Avoid sending too fast else: logger.warning("没有实现的回复消息类型。") else: + # 企业微信应用 for comp in message.chain: if isinstance(comp, Plain): # Split long text messages if needed From 66995db927661c4dc6bb9a5f60e4f7f967a3b1d8 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 May 2025 12:08:23 -0400 Subject: [PATCH 03/19] =?UTF-8?q?=E2=9C=A8=20feat:=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=BE=AE=E4=BF=A1=E5=AE=A2=E6=9C=8D=E5=9B=BE=E7=89=87=E6=B6=88?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../platform/sources/wecom/wecom_adapter.py | 18 ++++++++++---- .../platform/sources/wecom/wecom_event.py | 24 +++++++++++++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 234ba4e8..e4dd9077 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -293,18 +293,28 @@ class WecomPlatformAdapter(Platform): async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: msgtype = msg.get("msgtype", None) + external_userid = msg.get("external_userid", None) abm = AstrBotMessage() abm.raw_message = msg abm.raw_message["_wechat_kf_flag"] = None # 方便处理 abm.self_id = msg["open_kfid"] + abm.sender = MessageMember(external_userid, external_userid) + abm.session_id = external_userid + abm.type = MessageType.FRIEND_MESSAGE if msgtype == "text": - external_userid = msg.get("external_userid", None) text = msg.get("text", {}).get("content", "").strip() abm.message = [Plain(text=text)] - abm.sender = MessageMember(external_userid, external_userid) abm.message_str = text - abm.session_id = external_userid - abm.type = MessageType.FRIEND_MESSAGE + elif msgtype == "image": + media_id = msg.get("image", {}).get("media_id", "") + resp: Response = await asyncio.get_event_loop().run_in_executor( + None, self.client.media.download, media_id + ) + path = f"data/temp/wechat_kf_{media_id}.jpg" + with open(path, "wb") as f: + f.write(resp.content) + abm.message = [Image(file=path, url=path)] + abm.message_str = "[图片]" else: logger.warning(f"未实现的微信客服消息事件: {msg}") return diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 01cec423..507883e0 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -100,8 +100,26 @@ class WecomPlatformEvent(AstrMessageEvent): for chunk in plain_chunks: kf_message_api.send_text(user_id, self.get_self_id(), chunk) await asyncio.sleep(0.5) # Avoid sending too fast + elif isinstance(comp, Image): + img_path = await comp.convert_to_file_path() + + with open(img_path, "rb") as f: + try: + response = self.client.media.upload("image", f) + except Exception as e: + logger.error(f"微信客服上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传图片失败: {e}") + ) + return + logger.debug(f"微信客服上传图片返回: {response}") + kf_message_api.send_image( + user_id, + self.get_self_id(), + response["media_id"], + ) else: - logger.warning("没有实现的回复消息类型。") + logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") else: # 企业微信应用 for comp in message.chain: @@ -125,7 +143,7 @@ class WecomPlatformEvent(AstrMessageEvent): MessageChain().message(f"企业微信上传图片失败: {e}") ) return - logger.info(f"企业微信上传图片返回: {response}") + logger.debug(f"企业微信上传图片返回: {response}") self.client.message.send_image( message_obj.self_id, message_obj.session_id, @@ -154,6 +172,8 @@ class WecomPlatformEvent(AstrMessageEvent): message_obj.session_id, response["media_id"], ) + else: + logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") await super().send(message) From 7069b029293388887d4eda1e62c5fcd850a2682b Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 May 2025 12:11:55 -0400 Subject: [PATCH 04/19] chore: add license --- .../core/platform/sources/wecom/wecom_kf.py | 23 +++++++++++++++++++ .../sources/wecom/wecom_kf_message.py | 23 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/astrbot/core/platform/sources/wecom/wecom_kf.py b/astrbot/core/platform/sources/wecom/wecom_kf.py index 8ea5a6d5..316f6da3 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf.py @@ -1,5 +1,28 @@ # -*- coding: utf-8 -*- +""" +The MIT License (MIT) + +Copyright (c) 2014-2020 messense + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" from wechatpy.client.api.base import BaseWeChatAPI diff --git a/astrbot/core/platform/sources/wecom/wecom_kf_message.py b/astrbot/core/platform/sources/wecom/wecom_kf_message.py index fd9b943a..493d0405 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf_message.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf_message.py @@ -1,3 +1,26 @@ +""" +The MIT License (MIT) + +Copyright (c) 2014-2020 messense + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" from optionaldict import optionaldict From 3c8ec2f42e5e9003f580fe2ee4ef55c0b14b182a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 May 2025 12:47:21 -0400 Subject: [PATCH 05/19] =?UTF-8?q?=F0=9F=93=A6=20release:=20v3.5.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 2 +- changelogs/v3.5.7.md | 5 +++++ pyproject.toml | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 changelogs/v3.5.7.md diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 33e4bc3d..e5d3a201 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,7 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ -VERSION = "3.5.6" +VERSION = "3.5.7" DB_PATH = "data/data_v3.db" # 默认配置 diff --git a/changelogs/v3.5.7.md b/changelogs/v3.5.7.md new file mode 100644 index 00000000..f8097dfa --- /dev/null +++ b/changelogs/v3.5.7.md @@ -0,0 +1,5 @@ +# What's Changed + +> Gewechat 已经停止维护,此版本提供了 `微信客服` 的接入方式,可以在直接微信内聊天。这是微信官方推出的接入方式,因此没有风控风险。详见 [AstrBot 接入企业微信](https://astrbot.app/deploy/platform/wecom.html)。此接入方式处于测试阶段,有问题请及时在 GitHub 上提交 Issue。 + +1. 支持接入微信客服。 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 34778512..39ca7706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "AstrBot" -version = "3.5.6" +version = "3.5.7" description = "易上手的多平台 LLM 聊天机器人及开发框架" readme = "README.md" requires-python = ">=3.10" From dca1c0b0f3102e1bcc9dd6a26f7e5ef5583df5d3 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Tue, 6 May 2025 13:56:26 +0800 Subject: [PATCH 06/19] docs(README.md): update special thanks and platform --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6efd8463..b42f78b3 100644 --- a/README.md +++ b/README.md @@ -96,9 +96,10 @@ uv run main.py | -------- | ------- | ------- | ------ | | QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 | | QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 | -| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 | -| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 | -| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 | +| 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 | +| Telegram | ✔ | 私聊、群聊 | 文字、图片 | +| 企业微信 | ✔ | 私聊 | 文字、图片、语音 | +| 微信客服 | ✔ | 私聊 | 文字、图片 | | 飞书 | ✔ | 私聊、群聊 | 文字、图片 | | 钉钉 | ✔ | 私聊、群聊 | 文字、图片 | | 微信对话开放平台 | 🚧 | 计划内 | - | @@ -186,6 +187,9 @@ _✨ WebUI ✨_ +此外,本项目的诞生离不开以下开源项目: + +- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) ## ⭐ Star History From 9cc4e97a5385152a755605bf54d01a775452ad6c Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Tue, 6 May 2025 13:57:39 +0800 Subject: [PATCH 07/19] docs(README.md): update special thanks --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b42f78b3..10eea585 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,7 @@ _✨ WebUI ✨_ 此外,本项目的诞生离不开以下开源项目: - [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) +- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy) ## ⭐ Star History From c5bc7098984e9383ad8656b89d31f707be327163 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 6 May 2025 23:15:11 +0800 Subject: [PATCH 08/19] =?UTF-8?q?=F0=9F=8E=88=20perf:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=20openai=5Fsource=20=E6=96=B9=E6=B3=95=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/openai_source.py | 22 ++++++++++--------- packages/vpet/main.py | 19 ++++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 packages/vpet/main.py diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5399fbc3..f25ee3fc 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -21,7 +21,7 @@ from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List, AsyncGenerator from ..register import register_provider_adapter -from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult @register_provider_adapter( @@ -221,14 +221,16 @@ class ProviderOpenAIOfficial(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str] = [], + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts=[], - system_prompt=None, - tool_calls_result=None, + contexts: list=None, + system_prompt: str=None, + tool_calls_result: ToolCallsResult=None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" + if contexts is None: + contexts = [] new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -337,11 +339,11 @@ class ProviderOpenAIOfficial(Provider): async def text_chat( self, - prompt: str, - session_id: str = None, - image_urls: List[str] = [], - func_tool: FuncCall = None, - contexts=[], + prompt, + session_id = None, + image_urls = None, + func_tool = None, + contexts=None, system_prompt=None, tool_calls_result=None, **kwargs, diff --git a/packages/vpet/main.py b/packages/vpet/main.py new file mode 100644 index 00000000..6623cd6f --- /dev/null +++ b/packages/vpet/main.py @@ -0,0 +1,19 @@ +from astrbot.api.event import filter, AstrMessageEvent +from astrbot.api.star import Context, Star, register +from astrbot.api import logger + +@register("vpet", "AstrBot Team", "虚拟桌宠", "0.0.1") +class VPet(Star): + def __init__(self, context: Context): + super().__init__(context) + + async def initialize(self): + """可选择实现异步的插件初始化方法,当实例化该插件类之后会自动调用该方法。""" + + @filter.llm_tool("screenshot") + async def screenshot(self, event: AstrMessageEvent): + """Capture the screen and return the image.""" + + + async def terminate(self): + """可选择实现异步的插件销毁方法,当插件被卸载/停用时会调用。""" From 54c0dc1b2b482e72dd6aaeff2082ad1cb235a75f Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Wed, 7 May 2025 14:50:24 +0800 Subject: [PATCH 09/19] =?UTF-8?q?docs(README.md):=20=E4=B8=AA=E4=BA=BA?= =?UTF-8?q?=E5=BE=AE=E4=BF=A1=E6=8E=A5=E5=85=A5=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 10eea585..a36e1b0f 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,6 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=3600&style=for-the-badge&color=3b618e) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600) - English日本語查看文档 | @@ -28,11 +27,14 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。 -[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg?style=for-the-badge)](https://gitcode.com/Soulter/AstrBot) +> [!NOTE] +> +> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,我们正在评估其他方案(如 xxxbot 等)并将在数日内接入(很快!)。目前推荐微信用户暂时使用**微信官方**推出的企业微信接入方式和微信客服接入方式(版本 >= v3.5.7)。详情请前往 [#1443](https://github.com/AstrBotDevs/AstrBot/issues/1443) 讨论。 + ## ✨ 近期更新 1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器! From 752d13b1b1d4fdb830e3d2c36f93607602dcbe69 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Wed, 7 May 2025 19:04:24 +0800 Subject: [PATCH 10/19] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20gemini=5Fsou?= =?UTF-8?q?rce=20=E6=96=B9=E6=B3=95=E9=BB=98=E8=AE=A4=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index bf234953..fb47143d 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,7 +3,7 @@ import base64 import json import logging import random -from typing import Dict, List, Optional +from typing import Optional from collections.abc import AsyncGenerator from google import genai @@ -15,7 +15,7 @@ from astrbot import logger from astrbot.api.provider import Personality, Provider from astrbot.core.db import BaseDatabase from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.utils.io import download_image_by_url @@ -65,7 +65,7 @@ class ProviderGoogleGenAI(Provider): db_helper, default_persona, ) - self.api_keys: List = provider_config.get("key", []) + self.api_keys: list = provider_config.get("key", []) self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout: int = int(provider_config.get("timeout", 180)) @@ -99,7 +99,7 @@ class ProviderGoogleGenAI(Provider): and threshold_str in self.THRESHOLD_MAPPING ] - async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool: + async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool: """处理API错误,返回是否需要重试""" if e.code == 429 or "API key not valid" in e.message: keys.remove(self.chosen_api_key) @@ -126,7 +126,7 @@ class ProviderGoogleGenAI(Provider): payloads: dict, tools: Optional[FuncCall] = None, system_instruction: Optional[str] = None, - modalities: Optional[List[str]] = None, + modalities: Optional[list[str]] = None, temperature: float = 0.7, ) -> types.GenerateContentConfig: """准备查询配置""" @@ -195,7 +195,7 @@ class ProviderGoogleGenAI(Provider): ), ) - def _prepare_conversation(self, payloads: Dict) -> List[types.Content]: + def _prepare_conversation(self, payloads: dict) -> list[types.Content]: """准备 Gemini SDK 的 Content 列表""" def create_text_part(text: str) -> types.Part: @@ -220,7 +220,7 @@ class ProviderGoogleGenAI(Provider): else: contents.append(content_cls(parts=part)) - gemini_contents: List[types.Content] = [] + gemini_contents: list[types.Content] = [] native_tool_enabled = any( [ self.provider_config.get("gm_native_coderunner", False), @@ -464,13 +464,15 @@ class ProviderGoogleGenAI(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str] = None, + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts=[], - system_prompt=None, - tool_calls_result=None, + contexts: list = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: + if contexts is None: + contexts = [] new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -504,13 +506,15 @@ class ProviderGoogleGenAI(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str] = [], + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts=[], - system_prompt=None, - tool_calls_result=None, + contexts: str = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: + if contexts is None: + contexts = [] new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -556,14 +560,14 @@ class ProviderGoogleGenAI(Provider): def get_current_key(self) -> str: return self.chosen_api_key - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key): self.chosen_api_key = key self._init_client() - async def assemble_context(self, text: str, image_urls: List[str] = None): + async def assemble_context(self, text: str, image_urls: list[str] = None): """ 组装上下文。 """ From 626f94686b9ef882201139f0710893ba35855013 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 7 May 2025 08:57:22 -0400 Subject: [PATCH 11/19] =?UTF-8?q?=20=E2=9C=A8=20feat:=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=BE=AE=E4=BF=A1=E5=85=AC=E4=BC=97=E5=B9=B3=E5=8F=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 12 + astrbot/core/platform/manager.py | 2 + .../weixin_offacc_adapter.py | 252 ++++++++++++++++++ .../weixin_offacc_event.py | 147 ++++++++++ pyproject.toml | 1 + uv.lock | 13 +- 6 files changed, 426 insertions(+), 1 deletion(-) create mode 100644 astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py create mode 100644 astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index e5d3a201..c45d6247 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -151,6 +151,18 @@ CONFIG_METADATA_2 = { "host": "这里填写你的局域网IP或者公网服务器IP", "port": 11451, }, + "weixin_official_account(微信公众平台)": { + "id": "weixin_official_account", + "type": "weixin_official_account", + "enable": False, + "appid": "wx4cb77256a17de10a", + "secret": "", + "token": "", + "encoding_aes_key": "", + "api_base_url": "https://api.weixin.qq.com/cgi-bin/", + "callback_server_host": "0.0.0.0", + "port": 6194, + }, "wecom(企业微信)": { "id": "wecom", "type": "wecom", diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 22a06b73..4ac57544 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -72,6 +72,8 @@ class PlatformManager: from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401 case "wecom": from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401 + case "weixin_official_account": + from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa except (ImportError, ModuleNotFoundError) as e: logger.error( f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。" diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py new file mode 100644 index 00000000..d7463d4d --- /dev/null +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -0,0 +1,252 @@ +import sys +import uuid +import asyncio +import quart + +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + PlatformMetadata, + MessageType, +) +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain, Image, Record +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.api.platform import register_platform_adapter +from astrbot.core import logger +from requests import Response + +from wechatpy.utils import check_signature +from wechatpy.crypto import WeChatCrypto +from wechatpy import WeChatClient +from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage +from wechatpy.exceptions import InvalidSignatureException +from wechatpy import parse_message +from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class WecomServer: + def __init__(self, event_queue: asyncio.Queue, config: dict): + self.server = quart.Quart(__name__) + self.port = int(config.get("port")) + self.callback_server_host = config.get("callback_server_host", "0.0.0.0") + self.token = config.get("token") + self.encoding_aes_key = config.get("encoding_aes_key") + self.appid = config.get("appid") + self.server.add_url_rule( + "/callback/command", view_func=self.verify, methods=["GET"] + ) + self.server.add_url_rule( + "/callback/command", view_func=self.callback_command, methods=["POST"] + ) + self.crypto = WeChatCrypto(self.token, self.encoding_aes_key, self.appid) + + self.event_queue = event_queue + + self.callback = None + self.shutdown_event = asyncio.Event() + + async def verify(self): + logger.info(f"验证请求有效性: {quart.request.args}") + + args = quart.request.args + if not args.get("signature", None): + logger.error("未知的响应,请检查回调地址是否填写正确。") + return "err" + try: + check_signature( + self.token, + args.get("signature"), + args.get("timestamp"), + args.get("nonce"), + ) + logger.info("验证请求有效性成功。") + return args.get("echostr", "empty") + except InvalidSignatureException: + logger.error("验证请求有效性失败,签名异常,请检查配置。") + return "err" + + async def callback_command(self): + data = await quart.request.get_data() + msg_signature = quart.request.args.get("msg_signature") + timestamp = quart.request.args.get("timestamp") + nonce = quart.request.args.get("nonce") + try: + xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce) + except InvalidSignatureException: + logger.error("解密失败,签名异常,请检查配置。") + raise + else: + msg = parse_message(xml) + logger.info(f"解析成功: {msg}") + + if self.callback: + await self.callback(msg) + + return "success" + + async def start_polling(self): + logger.info( + f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。" + ) + await self.server.run_task( + host=self.callback_server_host, + port=self.port, + shutdown_trigger=self.shutdown_trigger, + ) + + async def shutdown_trigger(self): + await self.shutdown_event.wait() + + +@register_platform_adapter("weixin_official_account", "微信公众平台 适配器") +class WeixinOfficialAccountPlatformAdapter(Platform): + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: + super().__init__(event_queue) + self.config = platform_config + self.settingss = platform_settings + self.client_self_id = uuid.uuid4().hex[:8] + self.api_base_url = platform_config.get( + "api_base_url", "https://api.weixin.qq.com/cgi-bin/" + ) + + if not self.api_base_url: + self.api_base_url = "https://api.weixin.qq.com/cgi-bin/" + + if self.api_base_url.endswith("/"): + self.api_base_url = self.api_base_url[:-1] + if not self.api_base_url.endswith("/cgi-bin"): + self.api_base_url += "/cgi-bin" + + if not self.api_base_url.endswith("/"): + self.api_base_url += "/" + + self.server = WecomServer(self._event_queue, self.config) + + self.client = WeChatClient( + self.config["appid"].strip(), + self.config["secret"].strip(), + ) + + async def callback(msg): + try: + await self.convert_message(msg) + except Exception as e: + logger.error(f"转换消息时出现异常: {e}") + + self.server.callback = callback + + @override + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): + await super().send_by_session(session, message_chain) + + @override + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + "weixin_official_account", + "微信公众平台 适配器", + ) + + @override + async def run(self): + await self.server.start_polling() + + async def convert_message(self, msg) -> AstrBotMessage | None: + abm = AstrBotMessage() + if isinstance(msg, TextMessage): + abm.message_str = msg.content + abm.self_id = str(msg.target) + abm.message = [Plain(msg.content)] + abm.type = MessageType.FRIEND_MESSAGE + abm.sender = MessageMember( + msg.source, + msg.source, + ) + abm.message_id = msg.id + abm.timestamp = msg.time + abm.session_id = abm.sender.user_id + abm.raw_message = msg + elif msg.type == "image": + assert isinstance(msg, ImageMessage) + abm.message_str = "[图片]" + abm.self_id = str(msg.target) + abm.message = [Image(file=msg.image, url=msg.image)] + abm.type = MessageType.FRIEND_MESSAGE + abm.sender = MessageMember( + msg.source, + msg.source, + ) + abm.message_id = msg.id + abm.timestamp = msg.time + abm.session_id = abm.sender.user_id + abm.raw_message = msg + elif msg.type == "voice": + assert isinstance(msg, VoiceMessage) + + resp: Response = await asyncio.get_event_loop().run_in_executor( + None, self.client.media.download, msg.media_id + ) + path = f"data/temp/wecom_{msg.media_id}.amr" + with open(path, "wb") as f: + f.write(resp.content) + + try: + from pydub import AudioSegment + + path_wav = f"data/temp/wecom_{msg.media_id}.wav" + audio = AudioSegment.from_file(path) + audio.export(path_wav, format="wav") + except Exception as e: + logger.error(f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。") + path_wav = path + return + + abm.message_str = "" + abm.self_id = str(msg.target) + abm.message = [Record(file=path_wav, url=path_wav)] + abm.type = MessageType.FRIEND_MESSAGE + abm.sender = MessageMember( + msg.source, + msg.source, + ) + abm.message_id = msg.id + abm.timestamp = msg.time + abm.session_id = abm.sender.user_id + abm.raw_message = msg + else: + logger.warning(f"暂未实现的事件: {msg.type}") + return + + logger.info(f"abm: {abm}") + await self.handle_msg(abm) + + async def handle_msg(self, message: AstrBotMessage): + message_event = WeixinOfficialAccountPlatformEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + client=self.client, + ) + self.commit_event(message_event) + + def get_client(self) -> WeChatClient: + return self.client + + async def terminate(self): + self.server.shutdown_event.set() + try: + await self.server.server.shutdown() + except Exception as _: + pass + logger.info("微信公众平台 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py new file mode 100644 index 00000000..9519cd49 --- /dev/null +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -0,0 +1,147 @@ +import uuid +import asyncio +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.api.message_components import Plain, Image, Record +from wechatpy import WeChatClient + +from astrbot.api import logger + +try: + import pydub +except Exception: + logger.warning( + "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。" + ) + pass + + +class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + client: WeChatClient, + ): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.client = client + + @staticmethod + async def send_with_client( + client: WeChatClient, message: MessageChain, user_name: str + ): + pass + + async def split_plain(self, plain: str) -> list[str]: + """将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符 + + Args: + plain (str): 要分割的长文本 + Returns: + list[str]: 分割后的文本列表 + """ + if len(plain) <= 2048: + return [plain] + else: + result = [] + start = 0 + while start < len(plain): + # 剩下的字符串长度<2048时结束 + if start + 2048 >= len(plain): + result.append(plain[start:]) + break + + # 向前搜索分割标点符号 + end = min(start + 2048, len(plain)) + cut_position = end + for i in range(end, start, -1): + if i < len(plain) and plain[i - 1] in [ + "。", + "!", + "?", + ".", + "!", + "?", + "\n", + ";", + ";", + ]: + cut_position = i + break + + # 没找到合适的位置分割, 直接切分 + if cut_position == end and end < len(plain): + cut_position = end + + result.append(plain[start:cut_position]) + start = cut_position + + return result + + async def send(self, message: MessageChain): + message_obj = self.message_obj + for comp in message.chain: + if isinstance(comp, Plain): + # Split long text messages if needed + plain_chunks = await self.split_plain(comp.text) + for chunk in plain_chunks: + self.client.message.send_text(message_obj.sender.user_id, chunk) + await asyncio.sleep(0.5) # Avoid sending too fast + elif isinstance(comp, Image): + img_path = await comp.convert_to_file_path() + + with open(img_path, "rb") as f: + try: + response = self.client.media.upload("image", f) + except Exception as e: + logger.error(f"微信公众平台上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信公众平台上传图片失败: {e}") + ) + return + logger.debug(f"微信公众平台上传图片返回: {response}") + self.client.message.send_image( + message_obj.sender.user_id, + response["media_id"], + ) + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + # 转成amr + record_path_amr = f"data/temp/{uuid.uuid4()}.amr" + pydub.AudioSegment.from_wav(record_path).export( + record_path_amr, format="amr" + ) + + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"微信公众平台上传语音失败: {e}") + await self.send( + MessageChain().message(f"微信公众平台上传语音失败: {e}") + ) + return + logger.info(f"微信公众平台上传语音返回: {response}") + self.client.message.send_voice( + message_obj.sender.user_id, + response["media_id"], + ) + else: + logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") + + await super().send(message) + + async def send_streaming(self, generator, use_fallback: bool = False): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator, use_fallback) diff --git a/pyproject.toml b/pyproject.toml index 39ca7706..66474019 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "pip>=25.0.1", "psutil>=5.8.0", "pydantic~=2.10.3", + "pydub>=0.25.1", "pyjwt>=2.10.1", "python-telegram-bot>=22.0", "qq-botpy>=1.2.1", diff --git a/uv.lock b/uv.lock index 7cdd6be3..7f40f2e5 100644 --- a/uv.lock +++ b/uv.lock @@ -192,7 +192,7 @@ wheels = [ [[package]] name = "astrbot" -version = "3.4.39" +version = "3.5.7" source = { editable = "." } dependencies = [ { name = "aiocqhttp" }, @@ -220,6 +220,7 @@ dependencies = [ { name = "pip" }, { name = "psutil" }, { name = "pydantic" }, + { name = "pydub" }, { name = "pyjwt" }, { name = "python-telegram-bot" }, { name = "qq-botpy" }, @@ -257,6 +258,7 @@ requires-dist = [ { name = "pip", specifier = ">=25.0.1" }, { name = "psutil", specifier = ">=5.8.0" }, { name = "pydantic", specifier = "~=2.10.3" }, + { name = "pydub", specifier = ">=0.25.1" }, { name = "pyjwt", specifier = ">=2.10.1" }, { name = "python-telegram-bot", specifier = ">=22.0" }, { name = "qq-botpy", specifier = ">=1.2.1" }, @@ -1658,6 +1660,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/5f/d6d641b490fd3ec2c4c13b4244d68deea3a1b970a97be64f34fb5504ff72/pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef", size = 44356 }, ] +[[package]] +name = "pydub" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, +] + [[package]] name = "pyjwt" version = "2.10.1" From f40fa0eceab596632ffdb8ef592124e7387b80a3 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 7 May 2025 08:59:48 -0400 Subject: [PATCH 12/19] chore: remove useless config --- astrbot/core/config/default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index c45d6247..5e06c19d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -155,7 +155,7 @@ CONFIG_METADATA_2 = { "id": "weixin_official_account", "type": "weixin_official_account", "enable": False, - "appid": "wx4cb77256a17de10a", + "appid": "", "secret": "", "token": "", "encoding_aes_key": "", From e6bd7524c1acb96e997f5311ae6321da6bd9543c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 7 May 2025 09:49:07 -0400 Subject: [PATCH 13/19] =?UTF-8?q?=F0=9F=8E=88=20perf:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=20persona=20=E9=94=99=E8=AF=AF=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/astrbot/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 9dcd4a68..2f7b8ee3 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1017,6 +1017,8 @@ UID: {user_id} 此 ID 可用于设置管理员。 conversation = await self.context.conversation_manager.get_conversation( message.unified_msg_origin, cid ) + if not conversation: + message.set_result(MessageEventResult().message("请先进入一个对话。可以使用 /new 创建。")) if not conversation.persona_id and not conversation.persona_id == "[%None]": curr_persona_name = ( self.context.provider_manager.selected_default_persona["name"] From 3ace4199a1107bc726c3f02468256e7c4e74317e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 7 May 2025 09:51:45 -0400 Subject: [PATCH 14/19] =?UTF-8?q?=F0=9F=93=A6=20release:=20v3.5.8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 2 +- changelogs/v3.5.8.md | 5 +++++ pyproject.toml | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 changelogs/v3.5.8.md diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5e06c19d..0aaf2a6d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,7 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ -VERSION = "3.5.7" +VERSION = "3.5.8" DB_PATH = "data/data_v3.db" # 默认配置 diff --git a/changelogs/v3.5.8.md b/changelogs/v3.5.8.md new file mode 100644 index 00000000..bbab9fee --- /dev/null +++ b/changelogs/v3.5.8.md @@ -0,0 +1,5 @@ +# What's Changed + +1. 支持接入微信公众平台,详见 [AstrBot - 微信公众平台](https://astrbot.app/deploy/platform/weixin-official-account.html) @Soulter +2. 优化 gemini_source 方法默认参数 @Raven95678 +3. 优化 persona 错误显示 @Soulter \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 66474019..d7e7f8a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "AstrBot" -version = "3.5.7" +version = "3.5.8" description = "易上手的多平台 LLM 聊天机器人及开发框架" readme = "README.md" requires-python = ">=3.10" From 4a62f877dfc2288437d7094c3c474aa1a4c38ca7 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 9 May 2025 10:45:50 +0800 Subject: [PATCH 15/19] =?UTF-8?q?=F0=9F=90=9B=20fix:=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=8D=95=E7=8B=AC=E6=96=87=E4=BB=B6=E5=8F=91=E9=80=81=E6=97=B6?= =?UTF-8?q?=E8=A2=AB=E8=AE=A4=E4=B8=BA=E6=98=AF=E7=A9=BA=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E6=96=87=E4=BB=B6=E6=97=A0=E6=B3=95=E5=8F=91?= =?UTF-8?q?=E9=80=81=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/message/components.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 74538d09..73bfa727 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -559,12 +559,13 @@ class File(BaseMessageComponent): type: ComponentType = "File" name: T.Optional[str] = "" # 名字 - _file: T.Optional[str] = "" # 本地路径 + file_: T.Optional[str] = "" # 本地路径 url: T.Optional[str] = "" # url _downloaded: bool = False # 是否已经下载 - def __init__(self, name: str = "", file: str = "", url: str = ""): - super().__init__(name=name, _file=file, url=url) + def __init__(self, name: str, file: str, url: str = ""): + """文件消息段。一般情况下请直接使用 file 参数即可,可以传入文件路径或 URL,AstrBot 会自动识别。""" + super().__init__(name=name, file_=file, url=url) @property def file(self) -> str: @@ -574,8 +575,8 @@ class File(BaseMessageComponent): Returns: str: 文件路径 """ - if self._file and os.path.exists(self._file): - return self._file + if self.file_ and os.path.exists(self.file_): + return self.file_ if self.url and not self._downloaded: try: @@ -589,8 +590,8 @@ class File(BaseMessageComponent): # 等待下载完成 loop.run_until_complete(self._download_file()) - if self._file and os.path.exists(self._file): - return self._file + if self.file_ and os.path.exists(self.file_): + return self.file_ except Exception as e: logger.error(f"文件下载失败: {e}") @@ -607,7 +608,7 @@ class File(BaseMessageComponent): if value.startswith("http://") or value.startswith("https://"): self.url = value else: - self._file = value + self.file_ = value async def get_file(self) -> str: """ @@ -617,12 +618,12 @@ class File(BaseMessageComponent): Returns: str: 文件路径 """ - if self._file and os.path.exists(self._file): - return self._file + if self.file_ and os.path.exists(self.file_): + return self.file_ if self.url: await self._download_file() - return self._file + return self.file_ return "" @@ -637,7 +638,7 @@ class File(BaseMessageComponent): await download_file(self.url, file_path) - self._file = file_path + self.file_ = file_path self._downloaded = True From 790b924e57fb6e6b30b2d6e8c2647e433bb8e69e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 8 May 2025 23:41:08 -0400 Subject: [PATCH 16/19] =?UTF-8?q?refactor:=20QQ=20=E9=87=87=E7=94=A8=20htt?= =?UTF-8?q?p=20=E5=9B=9E=E8=B0=83=E7=9A=84=E6=96=B9=E5=BC=8F=E4=B8=8A?= =?UTF-8?q?=E6=8A=A5=E6=96=87=E4=BB=B6=E6=B6=88=E6=81=AF=E6=AE=B5=E4=B8=AD?= =?UTF-8?q?=E7=9A=84=E6=96=87=E4=BB=B6=E4=BF=A1=E6=81=AF=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix: 修复 Lagrange 下合并转发消息失败的问题 Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- astrbot/core/__init__.py | 18 +++--- astrbot/core/config/default.py | 7 +++ astrbot/core/file_token_service.py | 39 +++++++++++++ astrbot/core/message/components.py | 58 ++++++++++--------- astrbot/core/pipeline/respond/stage.py | 23 +------- .../aiocqhttp/aiocqhttp_message_event.py | 44 +++++++++----- astrbot/dashboard/routes/__init__.py | 4 +- astrbot/dashboard/routes/file.py | 27 +++++++++ astrbot/dashboard/server.py | 6 +- 9 files changed, 149 insertions(+), 77 deletions(-) create mode 100644 astrbot/core/file_token_service.py create mode 100644 astrbot/dashboard/routes/file.py diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 59e61d73..a9b1fafd 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -7,27 +7,27 @@ from astrbot.core.utils.pip_installer import PipInstaller from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.config.default import DB_PATH from astrbot.core.config import AstrBotConfig +from astrbot.core.file_token_service import FileTokenService # 初始化数据存储文件夹 os.makedirs("data", exist_ok=True) +WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" +DEMO_MODE = os.getenv("DEMO_MODE", False) + astrbot_config = AstrBotConfig() t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") html_renderer = HtmlRenderer(t2i_base_url) logger = LogManager.GetLogger(log_name="astrbot") - -if os.environ.get("TESTING", ""): - logger.setLevel("DEBUG") - db_helper = SQLiteDatabase(DB_PATH) -sp = ( - SharedPreferences() -) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 +# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 +sp = SharedPreferences() +# 文件令牌服务 +file_token_service = FileTokenService() pip_installer = PipInstaller( astrbot_config.get("pip_install_arg", ""), astrbot_config.get("pypi_index_url", None), ) web_chat_queue = asyncio.Queue(maxsize=32) web_chat_back_queue = asyncio.Queue(maxsize=32) -WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" -DEMO_MODE = os.getenv("DEMO_MODE", False) + diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 0aaf2a6d..341d105a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -104,6 +104,7 @@ DEFAULT_CONFIG = { "knowledge_db": {}, "persona": [], "timezone": "", + "callback_api_base": "", } @@ -1283,6 +1284,12 @@ CONFIG_METADATA_2 = { "obvious_hint": True, "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab", }, + "callback_api_base": { + "description": "对外可达的回调接口地址", + "type": "string", + "obvious_hint": True, + "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。" + }, "log_level": { "description": "控制台日志级别", "type": "string", diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py new file mode 100644 index 00000000..d1a8f611 --- /dev/null +++ b/astrbot/core/file_token_service.py @@ -0,0 +1,39 @@ +import asyncio +import os +import uuid + + +class FileTokenService: + """维护一个简单的基于令牌的文件下载服务""" + + def __init__(self): + self.lock = asyncio.Lock() + self.staged_files = {} + + async def register_file(self, file_path: str) -> str: + """向令牌服务注册一个文件。 + + Args: + file_path(str): 文件路径 + + Returns: + str: 一个单次令牌 + + Raises: + FileNotFoundError: 当路径不存在时抛出。 + """ + async with self.lock: + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + file_token = str(uuid.uuid4()) + self.staged_files[file_token] = file_path + return file_token + + async def handle_file(self, file_token: str) -> str: + async with self.lock: + if file_token not in self.staged_files: + raise KeyError(f"无效文件 token: {file_token}") + file_path = self.staged_files.pop(file_token, None) + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + return file_path diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 73bfa727..39a38f7b 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -462,10 +462,10 @@ class Node(BaseMessageComponent): type: ComponentType = "Node" id: T.Optional[int] = 0 # 忽略 name: T.Optional[str] = "" # qq昵称 - uin: T.Optional[int] = 0 # qq号 + uin: T.Optional[str] = "0" # qq号 content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表 seq: T.Optional[T.Union[str, list]] = "" # 忽略 - time: T.Optional[int] = 0 + time: T.Optional[int] = 0 # 忽略 def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_): if isinstance(content, list): @@ -494,8 +494,14 @@ class Nodes(BaseMessageComponent): super().__init__(nodes=nodes, **_) def toDict(self): - return {"messages": [node.toDict() for node in self.nodes]} - + ret = { + "messages": [], + } + for node in self.nodes: + d = node.toDict() + d["data"]["uin"] = str(node.uin) # 转为字符串 + ret["messages"].append(d) + return ret class Xml(BaseMessageComponent): type: ComponentType = "Xml" @@ -561,10 +567,9 @@ class File(BaseMessageComponent): name: T.Optional[str] = "" # 名字 file_: T.Optional[str] = "" # 本地路径 url: T.Optional[str] = "" # url - _downloaded: bool = False # 是否已经下载 - def __init__(self, name: str, file: str, url: str = ""): - """文件消息段。一般情况下请直接使用 file 参数即可,可以传入文件路径或 URL,AstrBot 会自动识别。""" + def __init__(self, name: str, file: str = "", url: str = ""): + """文件消息段。""" super().__init__(name=name, file_=file, url=url) @property @@ -576,22 +581,24 @@ class File(BaseMessageComponent): str: 文件路径 """ if self.file_ and os.path.exists(self.file_): - return self.file_ + return os.path.abspath(self.file_) if self.url and not self._downloaded: try: loop = asyncio.get_event_loop() if loop.is_running(): - logger.warning( - "不可以在异步上下文中同步等待下载! 请使用 await get_file() 代替" - ) + logger.warning(( + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段" + )) return "" else: # 等待下载完成 loop.run_until_complete(self._download_file()) if self.file_ and os.path.exists(self.file_): - return self.file_ + return os.path.abspath(self.file_) except Exception as e: logger.error(f"文件下载失败: {e}") @@ -610,36 +617,31 @@ class File(BaseMessageComponent): else: self.file_ = value - async def get_file(self) -> str: - """ - 异步获取文件 - To 插件开发者: 请注意在使用后清理下载的文件, 以免占用过多空间 + async def get_file(self, allow_return_url: bool=False) -> str: + """异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间 + Args: + allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。 + 注意,如果为 True,也可能返回文件路径。 Returns: - str: 文件路径 + str: 文件路径或者 http 下载链接 """ if self.file_ and os.path.exists(self.file_): - return self.file_ + return os.path.abspath(self.file_) if self.url: await self._download_file() - return self.file_ + return os.path.abspath(self.file_) return "" async def _download_file(self): """下载文件""" - if self._downloaded: - return - - os.makedirs("data/download", exist_ok=True) + os.makedirs("data/temp", exist_ok=True) filename = self.name or f"{uuid.uuid4().hex}" - file_path = f"data/download/{filename}" - + file_path = f"data/temp/{filename}" await download_file(self.url, file_path) - - self.file_ = file_path - self._downloaded = True + self.file_ = os.path.abspath(file_path) class WechatEmoji(BaseMessageComponent): diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 776f4a62..bff94a64 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -26,33 +26,14 @@ class RespondStage(Stage): Comp.Record: lambda comp: bool(comp.file), # 语音 Comp.Video: lambda comp: bool(comp.file), # 视频 Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ - Comp.AtAll: lambda comp: True, # @所有人 - Comp.RPS: lambda comp: True, # 不知道是啥(未完成) - Comp.Dice: lambda comp: True, # 骰子(未完成) - Comp.Shake: lambda comp: True, # 摇一摇(未完成) - Comp.Anonymous: lambda comp: True, # 匿名(未完成) - Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享 - Comp.Contact: lambda comp: True, # 联系人(未完成) - Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置 - Comp.Music: lambda comp: bool(comp._type) - and bool(comp.url) - and bool(comp.audio), # 音乐 Comp.Image: lambda comp: bool(comp.file), # 图片 Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 - Comp.RedBag: lambda comp: bool(comp.title), # 红包 Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳 - Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发 Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 - Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML - Comp.Json: lambda comp: bool(comp.data), # JSON - Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片 - Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成 - Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息 - Comp.File: lambda comp: bool(comp.file), # 文件 - Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情 + Comp.File: lambda comp: bool(comp.file_ or comp.url), } async def initialize(self, ctx: PipelineContext): @@ -129,8 +110,6 @@ class RespondStage(Stage): if comp_type in self._component_validators: if self._component_validators[comp_type](comp): return False - else: - logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}") # 如果所有组件都为空 return True diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 4acb677d..068a8bf3 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,8 +3,9 @@ import re from typing import AsyncGenerator, Dict, List from aiocqhttp import CQHttp from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record +from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File from astrbot.api.platform import Group, MessageMember +from astrbot.core import file_token_service, astrbot_config, logger class AiocqhttpMessageEvent(AstrMessageEvent): @@ -34,24 +35,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent): } elif isinstance(segment, At): d["data"] = { - "qq": str(segment.qq) # 转换为字符串 + "qq": str(segment.qq), # 转换为字符串 } ret.append(d) return ret async def send(self, message: MessageChain): - ret = await AiocqhttpMessageEvent._parse_onebot_json(message) - - if not ret: - return - - send_one_by_one = False - for seg in message.chain: - if isinstance(seg, (Node, Nodes)): - # 转发消息不能和普通消息混在一起发送 - send_one_by_one = True - break - + # 转发消息、文件消息不能和普通消息混在一起发送 + send_one_by_one = any( + isinstance(seg, (Node, Nodes, File)) for seg in message.chain + ) if send_one_by_one: for seg in message.chain: if isinstance(seg, (Node, Nodes)): @@ -70,6 +63,26 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await self.bot.call_action( "send_private_forward_msg", **payload ) + elif isinstance(seg, File): + d = seg.toDict() + url_or_path = await seg.get_file(allow_return_url=True) + if url_or_path.startswith("http"): + payload_file = url_or_path + elif callback_host := astrbot_config.get("callback_api_base"): + callback_host = str(callback_host).removesuffix("/") + token = await file_token_service.register_file(url_or_path) + payload_file = f"{callback_host}/api/file/{token}" + logger.debug(f"Generated file callback link: {payload_file}") + else: + payload_file = url_or_path + d["data"] = { + "name": seg.name, + "file": payload_file, + } + await self.bot.send( + self.message_obj.raw_message, + [d], + ) else: await self.bot.send( self.message_obj.raw_message, @@ -79,6 +92,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent): ) await asyncio.sleep(0.5) else: + ret = await AiocqhttpMessageEvent._parse_onebot_json(message) + if not ret: + return await self.bot.send(self.message_obj.raw_message, ret) await super().send(message) diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 3e24583e..f9309c3e 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -8,6 +8,7 @@ from .static_file import StaticFileRoute from .chat import ChatRoute from .tools import ToolsRoute # 导入新的ToolsRoute from .conversation import ConversationRoute +from .file import FileRoute __all__ = [ @@ -19,6 +20,7 @@ __all__ = [ "LogRoute", "StaticFileRoute", "ChatRoute", - "ToolsRoute", # 添加新的ToolsRoute + "ToolsRoute", "ConversationRoute", + "FileRoute", ] diff --git a/astrbot/dashboard/routes/file.py b/astrbot/dashboard/routes/file.py new file mode 100644 index 00000000..44c92ad0 --- /dev/null +++ b/astrbot/dashboard/routes/file.py @@ -0,0 +1,27 @@ +from .route import Route, RouteContext +from astrbot import logger +from quart import abort, send_file +from astrbot.core import file_token_service + + +class FileRoute(Route): + def __init__( + self, + context: RouteContext, + ) -> None: + super().__init__(context) + self.routes = { + "/file/": ("GET", self.serve_file), + } + self.register_routes() + + async def serve_file(self, file_token: str): + try: + file_path = await file_token_service.handle_file(file_token) + return await send_file(file_path) + except FileNotFoundError as e: + logger.warning(str(e)) + return abort(404) + except KeyError as e: + logger.warning(str(e)) + return abort(404) diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 5d131080..c85ada4e 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -52,15 +52,15 @@ class AstrBotDashboard: self.chat_route = ChatRoute(self.context, db, core_lifecycle) self.tools_root = ToolsRoute(self.context, core_lifecycle) self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) + self.file_route = FileRoute(self.context) self.shutdown_event = shutdown_event async def auth_middleware(self): if not request.path.startswith("/api"): return - if request.path == "/api/auth/login": - return - if request.path == "/api/chat/get_file": + allowed_endpoints = ["/api/auth/login", "/api/chat/get_file", "/api/file"] + if any(request.path.startswith(prefix) for prefix in allowed_endpoints): return # claim jwt token = request.headers.get("Authorization") From d9d94af022d2b10e23f435608cbe98148866df83 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 9 May 2025 04:00:12 -0400 Subject: [PATCH 17/19] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E5=A4=84=E7=90=86=E5=92=8C=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/dashboard/routes/file.py | 5 +---- astrbot/dashboard/routes/static_file.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/astrbot/dashboard/routes/file.py b/astrbot/dashboard/routes/file.py index 44c92ad0..8ea73d08 100644 --- a/astrbot/dashboard/routes/file.py +++ b/astrbot/dashboard/routes/file.py @@ -19,9 +19,6 @@ class FileRoute(Route): try: file_path = await file_token_service.handle_file(file_token) return await send_file(file_path) - except FileNotFoundError as e: - logger.warning(str(e)) - return abort(404) - except KeyError as e: + except (FileNotFoundError, KeyError) as e: logger.warning(str(e)) return abort(404) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 4503a28e..729fe854 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -28,7 +28,7 @@ class StaticFileRoute(Route): @self.app.errorhandler(404) async def page_not_found(e): - return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。" + return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" async def index(self): return await self.app.send_static_file("index.html") From 7fd765421f90f328427dd7b5e57c5a9618bcda79 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Fri, 9 May 2025 09:58:37 +0000 Subject: [PATCH 18/19] fix: [File] remove unused tags "_downloaded" --- astrbot/core/message/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 39a38f7b..718fd30f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -583,7 +583,7 @@ class File(BaseMessageComponent): if self.file_ and os.path.exists(self.file_): return os.path.abspath(self.file_) - if self.url and not self._downloaded: + if self.url: try: loop = asyncio.get_event_loop() if loop.is_running(): From 5b8f73cdd7ba89913a89604fcf71d6b786e3bb7b Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 9 May 2025 07:29:11 -0400 Subject: [PATCH 19/19] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E4=BB=A4?= =?UTF-8?q?=E7=89=8C=E8=B6=85=E6=97=B6=E6=97=B6=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/file_token_service.py | 45 ++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index d1a8f611..2ed46d43 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -1,39 +1,68 @@ import asyncio import os import uuid +import time class FileTokenService: - """维护一个简单的基于令牌的文件下载服务""" + """维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。""" - def __init__(self): + def __init__(self, default_timeout: float = 300): self.lock = asyncio.Lock() - self.staged_files = {} + self.staged_files = {} # token: (file_path, expire_time) + self.default_timeout = default_timeout - async def register_file(self, file_path: str) -> str: + async def _cleanup_expired_tokens(self): + """清理过期的令牌""" + now = time.time() + expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now] + for token in expired_tokens: + self.staged_files.pop(token, None) + + async def register_file(self, file_path: str, timeout: float = None) -> str: """向令牌服务注册一个文件。 Args: file_path(str): 文件路径 + timeout(float): 超时时间,单位秒(可选) Returns: str: 一个单次令牌 Raises: - FileNotFoundError: 当路径不存在时抛出。 + FileNotFoundError: 当路径不存在时抛出 """ async with self.lock: + await self._cleanup_expired_tokens() + if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") + file_token = str(uuid.uuid4()) - self.staged_files[file_token] = file_path + expire_time = time.time() + (timeout if timeout is not None else self.default_timeout) + self.staged_files[file_token] = (file_path, expire_time) return file_token async def handle_file(self, file_token: str) -> str: + """根据令牌获取文件路径,使用后令牌失效。 + + Args: + file_token(str): 注册时返回的令牌 + + Returns: + str: 文件路径 + + Raises: + KeyError: 当令牌不存在或已过期时抛出 + FileNotFoundError: 当文件本身已被删除时抛出 + """ async with self.lock: + await self._cleanup_expired_tokens() + if file_token not in self.staged_files: - raise KeyError(f"无效文件 token: {file_token}") - file_path = self.staged_files.pop(file_token, None) + raise KeyError(f"无效或过期的文件 token: {file_token}") + + file_path, _ = self.staged_files.pop(file_token) if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") return file_path