perf: 为不支持流式输出的平台提供fallback。

This commit is contained in:
Raven95676
2025-04-13 02:21:42 +08:00
parent fc146d3d00
commit b6963c1bf9
5 changed files with 128 additions and 52 deletions

View File

@@ -1,5 +1,7 @@
import asyncio import asyncio
import typing from typing import AsyncGenerator
import re
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import Group, MessageMember from astrbot.api.platform import Group, MessageMember
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
@@ -82,17 +84,30 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await super().send(message) await super().send(message)
async def send_streaming(self, generator): async def send_streaming(self, generator: AsyncGenerator):
buffer = None buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator: async for chain in generator:
if not buffer: if isinstance(chain, MessageChain):
buffer = chain for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else: else:
buffer.chain.extend(chain.chain) await self.send(MessageChain(chain=[comp]))
if not buffer:
return if buffer.strip():
buffer.squash_plain() await self.send(MessageChain([Plain(buffer)]))
await self.send(buffer)
return await super().send_streaming(generator) return await super().send_streaming(generator)
async def get_group(self, group_id=None, **kwargs): async def get_group(self, group_id=None, **kwargs):

View File

@@ -1,8 +1,12 @@
import asyncio import asyncio
import re
from typing import AsyncGenerator
import dingtalk_stream import dingtalk_stream
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot import logger from astrbot import logger
from astrbot.core.message.components import Plain
class DingtalkMessageEvent(AstrMessageEvent): class DingtalkMessageEvent(AstrMessageEvent):
@@ -61,15 +65,27 @@ class DingtalkMessageEvent(AstrMessageEvent):
await self.send_with_client(self.client, message) await self.send_with_client(self.client, message)
await super().send(message) await super().send(message)
async def send_streaming(self, generator): async def send_streaming(self, generator: AsyncGenerator):
buffer = None buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator: async for chain in generator:
if not buffer: if isinstance(chain, MessageChain):
buffer = chain for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else: else:
buffer.chain.extend(chain.chain) await self.send(MessageChain(chain=[comp]))
if not buffer: if buffer.strip():
return await self.send(MessageChain([Plain(buffer)]))
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator) return await super().send_streaming(generator)

View File

@@ -1,7 +1,10 @@
import asyncio
import re
import wave import wave
import uuid import uuid
import traceback import traceback
import os import os
from typing import AsyncGenerator
from astrbot.core.utils.io import save_temp_img, download_file from astrbot.core.utils.io import save_temp_img, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
@@ -217,15 +220,27 @@ class GewechatPlatformEvent(AstrMessageEvent):
members=members, members=members,
) )
async def send_streaming(self, generator): async def send_streaming(self, generator: AsyncGenerator):
buffer = None buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator: async for chain in generator:
if not buffer: if isinstance(chain, MessageChain):
buffer = chain for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else: else:
buffer.chain.extend(chain.chain) await self.send(MessageChain(chain=[comp]))
if not buffer: if buffer.strip():
return await self.send(MessageChain([Plain(buffer)]))
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator) return await super().send_streaming(generator)

View File

@@ -1,7 +1,9 @@
import asyncio
import json import json
import re
import uuid import uuid
import lark_oapi as lark import lark_oapi as lark
from typing import List from typing import List, AsyncGenerator
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.api.message_components import Plain, Image as AstrBotImage, At
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
@@ -92,15 +94,27 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message) await super().send(message)
async def send_streaming(self, generator): async def send_streaming(self, generator: AsyncGenerator):
buffer = None buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator: async for chain in generator:
if not buffer: if isinstance(chain, MessageChain):
buffer = chain for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else: else:
buffer.chain.extend(chain.chain) await self.send(MessageChain(chain=[comp]))
if not buffer: if buffer.strip():
return await self.send(MessageChain([Plain(buffer)]))
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator) return await super().send_streaming(generator)

View File

@@ -1,4 +1,8 @@
import asyncio
import re
import uuid import uuid
from typing import AsyncGenerator
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record from astrbot.api.message_components import Plain, Image, Record
@@ -85,15 +89,27 @@ class WecomPlatformEvent(AstrMessageEvent):
await super().send(message) await super().send(message)
async def send_streaming(self, generator): async def send_streaming(self, generator: AsyncGenerator):
buffer = None buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator: async for chain in generator:
if not buffer: if isinstance(chain, MessageChain):
buffer = chain for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else: else:
buffer.chain.extend(chain.chain) await self.send(MessageChain(chain=[comp]))
if not buffer: if buffer.strip():
return await self.send(MessageChain([Plain(buffer)]))
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator) return await super().send_streaming(generator)