From 19c9177d7b16ad8a865d522ccba6aa5d3204c71a Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Sun, 13 Apr 2025 17:03:06 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E7=A7=BB=E9=99=A4=E5=AF=B9dingtalk?= =?UTF-8?q?=E3=80=81lark=E3=80=81wecom=E7=9A=84fallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sources/dingtalk/dingtalk_event.py | 40 +++++-------------- .../core/platform/sources/lark/lark_event.py | 31 +++++--------- .../platform/sources/wecom/wecom_event.py | 40 +++++-------------- 3 files changed, 31 insertions(+), 80 deletions(-) diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index 5c981d5d..d850a759 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -1,12 +1,8 @@ import asyncio -import re -from typing import AsyncGenerator - import dingtalk_stream import astrbot.api.message_components as Comp from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot import logger -from astrbot.core.message.components import Plain class DingtalkMessageEvent(AstrMessageEvent): @@ -65,31 +61,15 @@ class DingtalkMessageEvent(AstrMessageEvent): await self.send_with_client(self.client, message) await super().send(message) - async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str: - while True: - match = re.search(pattern, buffer) - if not match: - break - matched_text = match.group() - await self.send(MessageChain([Plain(matched_text)])) - buffer = buffer[match.end() :] - await asyncio.sleep(0.5) # 限速 - return buffer - - async def send_streaming(self, generator: AsyncGenerator): - buffer = "" - pattern = re.compile(r"[^。?!~…]+[。?!~…]+") - + async def send_streaming(self, generator): + buffer = None async for chain in generator: - if isinstance(chain, MessageChain): - for comp in chain.chain: - if isinstance(comp, Plain): - buffer += comp.text - if any(p in buffer for p in "。?!~…"): - buffer = await self.process_buffer(buffer, pattern) - else: - await self.send(MessageChain(chain=[comp])) - - if buffer.strip(): - await self.send(MessageChain([Plain(buffer)])) + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 48c7c8d1..b1aee548 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,12 +1,9 @@ -import asyncio import json -import re import uuid import base64 import lark_oapi as lark - from io import BytesIO -from typing import List, AsyncGenerator +from typing import List from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.core.utils.io import download_image_by_url @@ -107,21 +104,15 @@ class LarkMessageEvent(AstrMessageEvent): await super().send(message) - async def send_streaming(self, generator: AsyncGenerator): - buffer = "" - pattern = re.compile(r"[^。?!~…]+[。?!~…]+") - + async def send_streaming(self, generator): + buffer = None async for chain in generator: - if isinstance(chain, MessageChain): - for comp in chain.chain: - if isinstance(comp, Plain): - buffer += comp.text - if any(p in buffer for p in "。?!~…"): - buffer = await self.process_buffer(buffer, pattern) - else: - await self.send(MessageChain(chain=[comp])) - await asyncio.sleep(0.8) # 限速 - - if buffer.strip(): - await self.send(MessageChain([Plain(buffer)])) + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 2f53e151..d8ee8b9a 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -1,8 +1,4 @@ -import asyncio -import re import uuid -from typing import AsyncGenerator - from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image, Record @@ -89,31 +85,15 @@ class WecomPlatformEvent(AstrMessageEvent): await super().send(message) - async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str: - while True: - match = re.search(pattern, buffer) - if not match: - break - matched_text = match.group() - await self.send(MessageChain([Plain(matched_text)])) - buffer = buffer[match.end() :] - await asyncio.sleep(0.5) # 限速 - return buffer - - async def send_streaming(self, generator: AsyncGenerator): - buffer = "" - pattern = re.compile(r"[^。?!~…]+[。?!~…]+") - + async def send_streaming(self, generator): + buffer = None async for chain in generator: - if isinstance(chain, MessageChain): - for comp in chain.chain: - if isinstance(comp, Plain): - buffer += comp.text - if any(p in buffer for p in "。?!~…"): - buffer = await self.process_buffer(buffer, pattern) - else: - await self.send(MessageChain(chain=[comp])) - - if buffer.strip(): - await self.send(MessageChain([Plain(buffer)])) + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) return await super().send_streaming(generator)