perf: 为不支持流式输出的平台提供fallback。
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from typing import AsyncGenerator
|
||||
import re
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
||||
@@ -82,17 +84,30 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
buffer = ""
|
||||
pattern = r"[^。?!~…]+[。?!~…]+"
|
||||
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(0.5) # 限速
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import dingtalk_stream
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Plain
|
||||
|
||||
|
||||
class DingtalkMessageEvent(AstrMessageEvent):
|
||||
@@ -61,15 +65,27 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
buffer = ""
|
||||
pattern = r"[^。?!~…]+[。?!~…]+"
|
||||
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(0.5) # 限速
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import re
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
@@ -217,15 +220,27 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
buffer = ""
|
||||
pattern = r"[^。?!~…]+[。?!~…]+"
|
||||
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(0.5) # 限速
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import lark_oapi as lark
|
||||
from typing import List
|
||||
from typing import List, AsyncGenerator
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
@@ -92,15 +94,27 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
buffer = ""
|
||||
pattern = r"[^。?!~…]+[。?!~…]+"
|
||||
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(0.5) # 限速
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
@@ -85,15 +89,27 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
buffer = ""
|
||||
pattern = r"[^。?!~…]+[。?!~…]+"
|
||||
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(0.5) # 限速
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
Reference in New Issue
Block a user