From 9b36a5c8a666918d9d7a249cc59a76804814ac15 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 6 Apr 2025 13:43:23 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20=E5=A2=9E=E5=8A=A0=E5=85=A8?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0=E5=AF=B9=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E7=9A=84=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/message/message_event_result.py | 17 ++++ .../process_stage/method/llm_request.py | 1 - .../aiocqhttp/aiocqhttp_message_event.py | 13 +++ .../sources/dingtalk/dingtalk_event.py | 19 ++++- .../sources/gewechat/gewechat_event.py | 13 +++ .../core/platform/sources/lark/lark_event.py | 13 +++ .../qqofficial/qqofficial_message_event.py | 82 +++++++++++++++++-- .../qqofficial_webhook/qo_webhook_adapter.py | 5 +- .../platform/sources/wecom/wecom_event.py | 13 +++ astrbot/dashboard/routes/config.py | 2 +- 10 files changed, 167 insertions(+), 11 deletions(-) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 0f7c4c7a..0e35e93f 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -111,6 +111,23 @@ class MessageChain: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + def squash_plain(self): + """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" + if not self.chain: + return + first_plain = None + to_delete = [] + for i, comp in enumerate(self.chain): + if isinstance(comp, Plain): + if first_plain is None: + first_plain = i + else: + self.chain[first_plain].text += comp.text + to_delete.append(i) + for i in reversed(to_delete): + self.chain.pop(i) + return self + class EventResultType(enum.Enum): """用于描述事件处理的结果类型。 diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index fafb8194..7b4403a7 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -156,7 +156,6 @@ class LLMRequestSubStage(Stage): ) async for llm_response in stream: if llm_response.is_chunk: - logger.debug(llm_response) if llm_response.result_chain: yield llm_response.result_chain # MessageChain else: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 295014ab..9bb8b938 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -82,6 +82,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await super().send(message) + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) + async def get_group(self, group_id=None, **kwargs): if isinstance(group_id, str) and group_id.isdigit(): group_id = int(group_id) diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index 7980ecd5..d850a759 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -24,7 +24,11 @@ class DingtalkMessageEvent(AstrMessageEvent): if isinstance(segment, Comp.Plain): segment.text = segment.text.strip() await asyncio.get_event_loop().run_in_executor( - None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message + None, + client.reply_markdown, + "AstrBot", + segment.text, + self.message_obj.raw_message, ) elif isinstance(segment, Comp.Image): markdown_str = "" @@ -56,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent): async def send(self, message: MessageChain): await self.send_with_client(self.client, message) await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 78902a4c..829a348c 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -216,3 +216,16 @@ class GewechatPlatformEvent(AstrMessageEvent): group_owner=data.get("chatRoomOwner"), members=members, ) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index e170b76a..544a7a5b 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -91,3 +91,16 @@ class LarkMessageEvent(AstrMessageEvent): logger.error(f"回复飞书消息失败({response.code}): {response.msg}") await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index d3100661..7a6183a2 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -2,6 +2,7 @@ import botpy import botpy.message import botpy.types import botpy.types.message +import asyncio from astrbot.core.utils.io import file_to_base64, download_image_by_url from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata @@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image from botpy import Client from botpy.http import Route from astrbot.api import logger +from botpy.types import message class QQOfficialMessageEvent(AstrMessageEvent): @@ -30,8 +32,41 @@ class QQOfficialMessageEvent(AstrMessageEvent): else: self.send_buffer.chain.extend(message.chain) - async def _post_send(self): + async def send_streaming(self, generator): + """流式输出仅支持消息列表私聊""" + stream_payload = {"state": 1, "id": None, "index": 0, "reset": False} + last_edit_time = 0 # 上次编辑消息的时间 + throttle_interval = 1 # 编辑消息的间隔时间 (秒) + async for chain in generator: + source = self.message_obj.raw_message + if not self.send_buffer: + self.send_buffer = chain + else: + self.send_buffer.chain.extend(chain.chain) + + if isinstance(source, botpy.message.C2CMessage): + # 真流式传输 + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + if time_since_last_edit >= throttle_interval: + ret = await self._post_send(stream=stream_payload) + stream_payload["index"] += 1 + stream_payload["id"] = ret["id"] + last_edit_time = asyncio.get_event_loop().time() + + if isinstance(source, botpy.message.C2CMessage): + # 结束流式对话,并且传输 buffer 中剩余的消息 + stream_payload["state"] = 10 + ret = await self._post_send(stream=stream_payload) + + return await super().send_streaming(generator) + + async def _post_send(self, stream: dict = None): """QQ 官方 API 仅支持回复一次""" + if not self.send_buffer: + return + source = self.message_obj.raw_message assert isinstance( source, @@ -65,7 +100,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - await self.bot.api.post_group_message( + ret = await self.bot.api.post_group_message( group_openid=source.group_openid, **payload ) case botpy.message.C2CMessage: @@ -75,22 +110,34 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - await self.bot.api.post_c2c_message( - openid=source.author.user_openid, **payload - ) + if stream: + ret = await self.post_c2c_message( + openid=source.author.user_openid, + **payload, + stream=stream, + ) + else: + ret = await self.post_c2c_message( + openid=source.author.user_openid, **payload + ) + logger.debug(f"Message sent to C2C: {ret}") case botpy.message.Message: if image_path: payload["file_image"] = image_path - await self.bot.api.post_message(channel_id=source.channel_id, **payload) + ret = await self.bot.api.post_message( + channel_id=source.channel_id, **payload + ) case botpy.message.DirectMessage: if image_path: payload["file_image"] = image_path - await self.bot.api.post_dms(guild_id=source.guild_id, **payload) + ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) await super().send(self.send_buffer) self.send_buffer = None + return ret + async def upload_group_and_c2c_image( self, image_base64: str, file_type: int, **kwargs ) -> botpy.types.message.Media: @@ -112,6 +159,27 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) return await self.bot.api._http.request(route, json=payload) + async def post_c2c_message( + self, + openid: str, + msg_type: int = 0, + content: str = None, + embed: message.Embed = None, + ark: message.Ark = None, + message_reference: message.Reference = None, + media: message.Media = None, + msg_id: str = None, + msg_seq: str = 1, + event_id: str = None, + markdown: message.MarkdownPayload = None, + keyboard: message.Keyboard = None, + stream: dict = None, + ) -> message.Message: + payload = locals() + payload.pop("self", None) + route = Route("POST", "/v2/users/{openid}/messages", openid=openid) + return await self.bot.api._http.request(route, json=payload) + @staticmethod async def _parse_to_qqofficial(message: MessageChain): plain_text = "" diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 226a1276..ede09e7f 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -116,5 +116,8 @@ class QQOfficialWebhookPlatformAdapter(Platform): async def terminate(self): self.webhook_helper.shutdown_event.set() await self.client.close() - await self.webhook_helper.server.shutdown() + try: + await self.webhook_helper.server.shutdown() + except Exception as _: + pass logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 470b7b1f..d8ee8b9a 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -84,3 +84,16 @@ class WecomPlatformEvent(AstrMessageEvent): ) await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 629a424f..2747865e 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -179,7 +179,7 @@ class ConfigRoute(Route): await self._save_astrbot_configs(post_configs) return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ except Exception as e: - logger.error(e) + logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ async def post_plugin_configs(self):