Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
a107921bc9 perf: enhance provider management with reload locking and logging
- Introduced a reload lock to prevent concurrent reloads of providers.
- Added logging to indicate when a provider is disabled and when providers are being synchronized with the configuration.
- Refactored the reload method to improve clarity and maintainability.
2025-11-27 15:30:53 +08:00

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import traceback import traceback
from astrbot.core import logger, sp from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
@@ -24,6 +24,7 @@ class ProviderManager:
db_helper: BaseDatabase, db_helper: BaseDatabase,
persona_mgr: PersonaManager, persona_mgr: PersonaManager,
): ):
self.reload_lock = asyncio.Lock()
self.persona_mgr = persona_mgr self.persona_mgr = persona_mgr
self.acm = acm self.acm = acm
config = acm.confs["default"] config = acm.confs["default"]
@@ -226,6 +227,7 @@ class ProviderManager:
async def load_provider(self, provider_config: dict): async def load_provider(self, provider_config: dict):
if not provider_config["enable"]: if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return return
if provider_config.get("provider_type", "") == "agent_runner": if provider_config.get("provider_type", "") == "agent_runner":
return return
@@ -434,40 +436,46 @@ class ProviderManager:
) )
async def reload(self, provider_config: dict): async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config["id"]) async with self.reload_lock:
if provider_config["enable"]: await self.terminate_provider(provider_config["id"])
await self.load_provider(provider_config) if provider_config["enable"]:
await self.load_provider(provider_config)
# 和配置文件保持同步 # 和配置文件保持同步
config_ids = [provider["id"] for provider in self.providers_config] self.providers_config = astrbot_config["provider"]
logger.debug(f"providers in user's config: {config_ids}") config_ids = [provider["id"] for provider in self.providers_config]
for key in list(self.inst_map.keys()): logger.info(f"providers in user's config: {config_ids}")
if key not in config_ids: for key in list(self.inst_map.keys()):
await self.terminate_provider(key) if key not in config_ids:
await self.terminate_provider(key)
if len(self.provider_insts) == 0: if len(self.provider_insts) == 0:
self.curr_provider_inst = None self.curr_provider_inst = None
elif self.curr_provider_inst is None and len(self.provider_insts) > 0: elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0] self.curr_provider_inst = self.provider_insts[0]
logger.info( logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。", f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
) )
if len(self.stt_provider_insts) == 0: if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None self.curr_stt_provider_inst = None
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: elif (
self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
logger.info( ):
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。", self.curr_stt_provider_inst = self.stt_provider_insts[0]
) logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
)
if len(self.tts_provider_insts) == 0: if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None self.curr_tts_provider_inst = None
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: elif (
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
logger.info( ):
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", self.curr_tts_provider_inst = self.tts_provider_insts[0]
) logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
)
def get_insts(self): def get_insts(self):
return self.provider_insts return self.provider_insts