Compare commits
1 Commits
master
...
refactor-2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a107921bc9 |
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from astrbot.core import logger, sp
|
from astrbot.core import astrbot_config, logger, sp
|
||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
|
|
||||||
@@ -24,6 +24,7 @@ class ProviderManager:
|
|||||||
db_helper: BaseDatabase,
|
db_helper: BaseDatabase,
|
||||||
persona_mgr: PersonaManager,
|
persona_mgr: PersonaManager,
|
||||||
):
|
):
|
||||||
|
self.reload_lock = asyncio.Lock()
|
||||||
self.persona_mgr = persona_mgr
|
self.persona_mgr = persona_mgr
|
||||||
self.acm = acm
|
self.acm = acm
|
||||||
config = acm.confs["default"]
|
config = acm.confs["default"]
|
||||||
@@ -226,6 +227,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
async def load_provider(self, provider_config: dict):
|
async def load_provider(self, provider_config: dict):
|
||||||
if not provider_config["enable"]:
|
if not provider_config["enable"]:
|
||||||
|
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||||||
return
|
return
|
||||||
if provider_config.get("provider_type", "") == "agent_runner":
|
if provider_config.get("provider_type", "") == "agent_runner":
|
||||||
return
|
return
|
||||||
@@ -434,40 +436,46 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def reload(self, provider_config: dict):
|
async def reload(self, provider_config: dict):
|
||||||
await self.terminate_provider(provider_config["id"])
|
async with self.reload_lock:
|
||||||
if provider_config["enable"]:
|
await self.terminate_provider(provider_config["id"])
|
||||||
await self.load_provider(provider_config)
|
if provider_config["enable"]:
|
||||||
|
await self.load_provider(provider_config)
|
||||||
|
|
||||||
# 和配置文件保持同步
|
# 和配置文件保持同步
|
||||||
config_ids = [provider["id"] for provider in self.providers_config]
|
self.providers_config = astrbot_config["provider"]
|
||||||
logger.debug(f"providers in user's config: {config_ids}")
|
config_ids = [provider["id"] for provider in self.providers_config]
|
||||||
for key in list(self.inst_map.keys()):
|
logger.info(f"providers in user's config: {config_ids}")
|
||||||
if key not in config_ids:
|
for key in list(self.inst_map.keys()):
|
||||||
await self.terminate_provider(key)
|
if key not in config_ids:
|
||||||
|
await self.terminate_provider(key)
|
||||||
|
|
||||||
if len(self.provider_insts) == 0:
|
if len(self.provider_insts) == 0:
|
||||||
self.curr_provider_inst = None
|
self.curr_provider_inst = None
|
||||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||||
self.curr_provider_inst = self.provider_insts[0]
|
self.curr_provider_inst = self.provider_insts[0]
|
||||||
logger.info(
|
logger.info(
|
||||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(self.stt_provider_insts) == 0:
|
if len(self.stt_provider_insts) == 0:
|
||||||
self.curr_stt_provider_inst = None
|
self.curr_stt_provider_inst = None
|
||||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
elif (
|
||||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
|
||||||
logger.info(
|
):
|
||||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||||
)
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.tts_provider_insts) == 0:
|
if len(self.tts_provider_insts) == 0:
|
||||||
self.curr_tts_provider_inst = None
|
self.curr_tts_provider_inst = None
|
||||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
elif (
|
||||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
|
||||||
logger.info(
|
):
|
||||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||||
)
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
||||||
|
)
|
||||||
|
|
||||||
def get_insts(self):
|
def get_insts(self):
|
||||||
return self.provider_insts
|
return self.provider_insts
|
||||||
|
|||||||
Reference in New Issue
Block a user