Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter a107921bc9 perf: enhance provider management with reload locking and logging
- Introduced a reload lock to prevent concurrent reloads of providers.
- Added logging to indicate when a provider is disabled and when providers are being synchronized with the configuration.
- Refactored the reload method to improve clarity and maintainability.
2025-11-27 15:30:53 +08:00
2 changed files with 40 additions and 34 deletions
+39 -31
View File
@@ -1,7 +1,7 @@
import asyncio import asyncio
import traceback import traceback
from astrbot.core import logger, sp from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
@@ -24,6 +24,7 @@ class ProviderManager:
db_helper: BaseDatabase, db_helper: BaseDatabase,
persona_mgr: PersonaManager, persona_mgr: PersonaManager,
): ):
self.reload_lock = asyncio.Lock()
self.persona_mgr = persona_mgr self.persona_mgr = persona_mgr
self.acm = acm self.acm = acm
config = acm.confs["default"] config = acm.confs["default"]
@@ -226,6 +227,7 @@ class ProviderManager:
async def load_provider(self, provider_config: dict): async def load_provider(self, provider_config: dict):
if not provider_config["enable"]: if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return return
if provider_config.get("provider_type", "") == "agent_runner": if provider_config.get("provider_type", "") == "agent_runner":
return return
@@ -434,40 +436,46 @@ class ProviderManager:
) )
async def reload(self, provider_config: dict): async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config["id"]) async with self.reload_lock:
if provider_config["enable"]: await self.terminate_provider(provider_config["id"])
await self.load_provider(provider_config) if provider_config["enable"]:
await self.load_provider(provider_config)
# 和配置文件保持同步 # 和配置文件保持同步
config_ids = [provider["id"] for provider in self.providers_config] self.providers_config = astrbot_config["provider"]
logger.debug(f"providers in user's config: {config_ids}") config_ids = [provider["id"] for provider in self.providers_config]
for key in list(self.inst_map.keys()): logger.info(f"providers in user's config: {config_ids}")
if key not in config_ids: for key in list(self.inst_map.keys()):
await self.terminate_provider(key) if key not in config_ids:
await self.terminate_provider(key)
if len(self.provider_insts) == 0: if len(self.provider_insts) == 0:
self.curr_provider_inst = None self.curr_provider_inst = None
elif self.curr_provider_inst is None and len(self.provider_insts) > 0: elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0] self.curr_provider_inst = self.provider_insts[0]
logger.info( logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。", f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
) )
if len(self.stt_provider_insts) == 0: if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None self.curr_stt_provider_inst = None
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: elif (
self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
logger.info( ):
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。", self.curr_stt_provider_inst = self.stt_provider_insts[0]
) logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
)
if len(self.tts_provider_insts) == 0: if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None self.curr_tts_provider_inst = None
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: elif (
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
logger.info( ):
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", self.curr_tts_provider_inst = self.tts_provider_insts[0]
) logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
)
def get_insts(self): def get_insts(self):
return self.provider_insts return self.provider_insts
+1 -3
View File
@@ -6,7 +6,7 @@ from collections import defaultdict
from astrbot import logger from astrbot import logger
from astrbot.api import star from astrbot.api import star
from astrbot.api.event import AstrMessageEvent from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import At, Image, Plain from astrbot.api.message_components import Image, Plain
from astrbot.api.platform import MessageType from astrbot.api.platform import MessageType
from astrbot.api.provider import Provider, ProviderRequest from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -142,8 +142,6 @@ class LongTermMemory:
logger.error(f"获取图片描述失败: {e}") logger.error(f"获取图片描述失败: {e}")
else: else:
parts.append(" [Image]") parts.append(" [Image]")
elif isinstance(comp, At):
parts.append(f" [At: {comp.name}]")
final_message = "".join(parts) final_message = "".join(parts)
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")