perf: 为不支持流式输出的平台提供fallback。

This commit is contained in:
Raven95676
2025-04-13 02:21:42 +08:00
parent fc146d3d00
commit b6963c1bf9
5 changed files with 128 additions and 52 deletions

View File

@@ -1,5 +1,7 @@
import asyncio
import typing
from typing import AsyncGenerator
import re
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import Group, MessageMember
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
@@ -82,17 +84,30 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async def send_streaming(self, generator: AsyncGenerator):
buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else:
await self.send(MessageChain(chain=[comp]))
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)
async def get_group(self, group_id=None, **kwargs):

View File

@@ -1,8 +1,12 @@
import asyncio
import re
from typing import AsyncGenerator
import dingtalk_stream
import astrbot.api.message_components as Comp
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot import logger
from astrbot.core.message.components import Plain
class DingtalkMessageEvent(AstrMessageEvent):
@@ -61,15 +65,27 @@ class DingtalkMessageEvent(AstrMessageEvent):
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async def send_streaming(self, generator: AsyncGenerator):
buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else:
await self.send(MessageChain(chain=[comp]))
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)

View File

@@ -1,7 +1,10 @@
import asyncio
import re
import wave
import uuid
import traceback
import os
from typing import AsyncGenerator
from astrbot.core.utils.io import save_temp_img, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
@@ -217,15 +220,27 @@ class GewechatPlatformEvent(AstrMessageEvent):
members=members,
)
async def send_streaming(self, generator):
buffer = None
async def send_streaming(self, generator: AsyncGenerator):
buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else:
await self.send(MessageChain(chain=[comp]))
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)

View File

@@ -1,7 +1,9 @@
import asyncio
import json
import re
import uuid
import lark_oapi as lark
from typing import List
from typing import List, AsyncGenerator
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
from astrbot.core.utils.io import download_image_by_url
@@ -92,15 +94,27 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async def send_streaming(self, generator: AsyncGenerator):
buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else:
await self.send(MessageChain(chain=[comp]))
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)

View File

@@ -1,4 +1,8 @@
import asyncio
import re
import uuid
from typing import AsyncGenerator
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
@@ -85,15 +89,27 @@ class WecomPlatformEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async def send_streaming(self, generator: AsyncGenerator):
buffer = ""
pattern = r"[^。?!~…]+[。?!~…]+"
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(0.5) # 限速
else:
await self.send(MessageChain(chain=[comp]))
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)