From 0f9ab082abed5dafc507089c8e6e254867848cb8 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 11 Jan 2025 19:45:42 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96webchat=EF=BC=8C?= =?UTF-8?q?=E6=B2=A1=E6=9C=89=E7=BB=93=E6=9E=9C=E8=BF=94=E5=9B=9E=E6=97=B6?= =?UTF-8?q?=E7=9A=84=E5=8F=8D=E9=A6=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/__init__.py | 4 ++-- astrbot/core/pipeline/scheduler.py | 4 ++++ astrbot/core/platform/sources/webchat/webchat_event.py | 10 +++++++--- astrbot/core/provider/manager.py | 10 ++++++---- astrbot/dashboard/routes/chat.py | 6 +++++- 5 files changed, 24 insertions(+), 10 deletions(-) diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 0ef8e039..aac8fc11 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -20,6 +20,6 @@ if os.environ.get('TESTING', ""): db_helper = SQLiteDatabase(DB_PATH) sp = SharedPreferences() # 简单的偏好设置存储 pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', '')) -web_chat_queue = asyncio.Queue() -web_chat_back_queue = asyncio.Queue() +web_chat_queue = asyncio.Queue(maxsize=32) +web_chat_back_queue = asyncio.Queue(maxsize=32) WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 8842057b..e18ee92b 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -41,4 +41,8 @@ class PipelineScheduler(): async def execute(self, event: AstrMessageEvent): '''执行 pipeline''' await self._process_stages(event) + + if not event._has_send_oper and event.get_platform_name() == "webchat": + await event.send(None) + logger.debug("pipeline 执行完毕。") \ No newline at end of file diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index c988724b..4312b0cb 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -12,9 +12,13 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(self.imgs_dir, exist_ok=True) async def send(self, message: MessageChain): + if not message: + await web_chat_back_queue.put_nowait(None) + return + for comp in message.chain: if isinstance(comp, Plain): - await web_chat_back_queue.put(comp.text) + await web_chat_back_queue.put_nowait(comp.text) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -26,6 +30,6 @@ class WebChatMessageEvent(AstrMessageEvent): f.write(f2.read()) elif comp.file and comp.file.startswith("http"): await download_image_by_url(comp.file, path=path) - await web_chat_back_queue.put(f"[IMAGE]{filename}") - await web_chat_back_queue.put(None) + await web_chat_back_queue.put_nowait(f"[IMAGE]{filename}") + await web_chat_back_queue.put_nowait(None) await super().send(message) \ No newline at end of file diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 4f70c33f..7523960e 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -64,6 +64,8 @@ class ProviderManager(): continue selected_provider_id = sp.get("curr_provider") selected_stt_provider_id = self.provider_stt_settings.get("provider_id") + provider_enabled = self.provider_settings.get("enable", False) + stt_enabled = self.provider_stt_settings.get("enable", False) provider_metadata = provider_cls_map[provider_config['type']] logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") @@ -74,7 +76,7 @@ class ProviderManager(): # STT 任务 inst = provider_metadata.cls_type(provider_config, self.provider_settings) self.stt_provider_insts.append(inst) - if selected_stt_provider_id == provider_config['id']: + if selected_stt_provider_id == provider_config['id'] and stt_enabled: self.curr_stt_provider_inst = inst logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") @@ -82,7 +84,7 @@ class ProviderManager(): # 文本生成任务 inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) self.provider_insts.append(inst) - if selected_provider_id == provider_config['id']: + if selected_provider_id == provider_config['id'] and provider_enabled: self.curr_provider_inst = inst logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") @@ -90,10 +92,10 @@ class ProviderManager(): traceback.print_exc() logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") - if len(self.provider_insts) > 0 and not self.curr_provider_inst: + if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled: self.curr_provider_inst = self.provider_insts[0] - if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst: + if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled: self.curr_stt_provider_inst = self.stt_provider_insts[0] if not self.curr_provider_inst: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 45bcb544..639c7caf 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -116,7 +116,11 @@ class ChatRoute(Route): async def stream(): ret = [] while True: - result = await web_chat_back_queue.get() + try: + result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒 + except asyncio.TimeoutError: + yield '[Error] 30 秒内没有返回数据,已放弃。\n' + return if result is None: break