✨ feat: 增加全平台对流式输出的处理逻辑
This commit is contained in:
@@ -111,6 +111,23 @@ class MessageChain:
|
||||
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
|
||||
def squash_plain(self):
|
||||
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||
if not self.chain:
|
||||
return
|
||||
first_plain = None
|
||||
to_delete = []
|
||||
for i, comp in enumerate(self.chain):
|
||||
if isinstance(comp, Plain):
|
||||
if first_plain is None:
|
||||
first_plain = i
|
||||
else:
|
||||
self.chain[first_plain].text += comp.text
|
||||
to_delete.append(i)
|
||||
for i in reversed(to_delete):
|
||||
self.chain.pop(i)
|
||||
return self
|
||||
|
||||
|
||||
class EventResultType(enum.Enum):
|
||||
"""用于描述事件处理的结果类型。
|
||||
|
||||
@@ -156,7 +156,6 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
logger.debug(llm_response)
|
||||
if llm_response.result_chain:
|
||||
yield llm_response.result_chain # MessageChain
|
||||
else:
|
||||
|
||||
@@ -82,6 +82,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if isinstance(group_id, str) and group_id.isdigit():
|
||||
group_id = int(group_id)
|
||||
|
||||
@@ -24,7 +24,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, Comp.Plain):
|
||||
segment.text = segment.text.strip()
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"AstrBot",
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
@@ -56,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -216,3 +216,16 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
group_owner=data.get("chatRoomOwner"),
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -91,3 +91,16 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -2,6 +2,7 @@ import botpy
|
||||
import botpy.message
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
import asyncio
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
@@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
from botpy.types import message
|
||||
|
||||
|
||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
@@ -30,8 +32,41 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def _post_send(self):
|
||||
async def send_streaming(self, generator):
|
||||
"""流式输出仅支持消息列表私聊"""
|
||||
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||
async for chain in generator:
|
||||
source = self.message_obj.raw_message
|
||||
if not self.send_buffer:
|
||||
self.send_buffer = chain
|
||||
else:
|
||||
self.send_buffer.chain.extend(chain.chain)
|
||||
|
||||
if isinstance(source, botpy.message.C2CMessage):
|
||||
# 真流式传输
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
stream_payload["index"] += 1
|
||||
stream_payload["id"] = ret["id"]
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
|
||||
if isinstance(source, botpy.message.C2CMessage):
|
||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||
stream_payload["state"] = 10
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
async def _post_send(self, stream: dict = None):
|
||||
"""QQ 官方 API 仅支持回复一次"""
|
||||
if not self.send_buffer:
|
||||
return
|
||||
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(
|
||||
source,
|
||||
@@ -65,7 +100,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
await self.bot.api.post_group_message(
|
||||
ret = await self.bot.api.post_group_message(
|
||||
group_openid=source.group_openid, **payload
|
||||
)
|
||||
case botpy.message.C2CMessage:
|
||||
@@ -75,22 +110,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
await self.bot.api.post_c2c_message(
|
||||
openid=source.author.user_openid, **payload
|
||||
)
|
||||
if stream:
|
||||
ret = await self.post_c2c_message(
|
||||
openid=source.author.user_openid,
|
||||
**payload,
|
||||
stream=stream,
|
||||
)
|
||||
else:
|
||||
ret = await self.post_c2c_message(
|
||||
openid=source.author.user_openid, **payload
|
||||
)
|
||||
logger.debug(f"Message sent to C2C: {ret}")
|
||||
case botpy.message.Message:
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
|
||||
ret = await self.bot.api.post_message(
|
||||
channel_id=source.channel_id, **payload
|
||||
)
|
||||
case botpy.message.DirectMessage:
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
|
||||
return ret
|
||||
|
||||
async def upload_group_and_c2c_image(
|
||||
self, image_base64: str, file_type: int, **kwargs
|
||||
) -> botpy.types.message.Media:
|
||||
@@ -112,6 +159,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
async def post_c2c_message(
|
||||
self,
|
||||
openid: str,
|
||||
msg_type: int = 0,
|
||||
content: str = None,
|
||||
embed: message.Embed = None,
|
||||
ark: message.Ark = None,
|
||||
message_reference: message.Reference = None,
|
||||
media: message.Media = None,
|
||||
msg_id: str = None,
|
||||
msg_seq: str = 1,
|
||||
event_id: str = None,
|
||||
markdown: message.MarkdownPayload = None,
|
||||
keyboard: message.Keyboard = None,
|
||||
stream: dict = None,
|
||||
) -> message.Message:
|
||||
payload = locals()
|
||||
payload.pop("self", None)
|
||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
@staticmethod
|
||||
async def _parse_to_qqofficial(message: MessageChain):
|
||||
plain_text = ""
|
||||
|
||||
@@ -116,5 +116,8 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
async def terminate(self):
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
await self.client.close()
|
||||
await self.webhook_helper.server.shutdown()
|
||||
try:
|
||||
await self.webhook_helper.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||
|
||||
@@ -84,3 +84,16 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -179,7 +179,7 @@ class ConfigRoute(Route):
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def post_plugin_configs(self):
|
||||
|
||||
Reference in New Issue
Block a user