🍺 refactor: 支持更大范围的热重载以及管理面板将平台和提供商配置独立化

This commit is contained in:
Soulter
2025-02-23 12:54:25 +08:00
parent cf22eae467
commit da14a89490
15 changed files with 844 additions and 192 deletions

View File

@@ -154,7 +154,8 @@ CONFIG_METADATA_2 = {
"id": { "id": {
"description": "ID", "description": "ID",
"type": "string", "type": "string",
"hint": "用于在多实例下方便管理和识别。自定义ID 不能重复。", "obvious_hint": True,
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
}, },
"type": { "type": {
"description": "适配器类型", "description": "适配器类型",
@@ -630,7 +631,8 @@ CONFIG_METADATA_2 = {
"id": { "id": {
"description": "ID", "description": "ID",
"type": "string", "type": "string",
"hint": "提供商 ID 名用于在多实例下方便管理和识别。自定义ID 不能重复。", "obvious_hint": True,
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
}, },
"type": { "type": {
"description": "模型提供商类型", "description": "模型提供商类型",

View File

@@ -63,9 +63,6 @@ class AstrBotCoreLifecycle:
await self.provider_manager.initialize() await self.provider_manager.initialize()
'''根据配置实例化各个 Provider''' '''根据配置实例化各个 Provider'''
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager)) self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
await self.pipeline_scheduler.initialize() await self.pipeline_scheduler.initialize()
'''初始化消息事件流水线调度器''' '''初始化消息事件流水线调度器'''
@@ -74,19 +71,18 @@ class AstrBotCoreLifecycle:
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
self.start_time = int(time.time()) self.start_time = int(time.time())
self.curr_tasks: List[asyncio.Task] = [] self.curr_tasks: List[asyncio.Task] = []
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
def _load(self): def _load(self):
platform_tasks = self.load_platform()
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus") event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
extra_tasks = [] extra_tasks = []
for task in self.star_context._register_tasks: for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) extra_tasks.append(asyncio.create_task(task, name=task.__name__))
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks] tasks_ = [event_bus_task, *extra_tasks]
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
for task in tasks_: for task in tasks_:
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name())) self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))

View File

@@ -1,3 +1,5 @@
import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from .platform import Platform from .platform import Platform
from typing import List from typing import List
@@ -11,43 +13,102 @@ class PlatformManager():
self.platform_insts: List[Platform] = [] self.platform_insts: List[Platform] = []
'''加载的 Platform 的实例''' '''加载的 Platform 的实例'''
self._inst_map = {}
self.platforms_config = config['platform'] self.platforms_config = config['platform']
self.settings = config['platform_settings'] self.settings = config['platform_settings']
self.event_queue = event_queue self.event_queue = event_queue
try:
for platform in self.platforms_config:
if not platform['enable']:
continue
match platform['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
except Exception as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}")
async def initialize(self): async def initialize(self):
'''初始化所有平台适配器'''
for platform in self.platforms_config: for platform in self.platforms_config:
if not platform['enable']: await self.load_platform(platform)
continue
if platform['type'] not in platform_cls_map:
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = platform_cls_map[platform['type']]
logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue)) # 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
asyncio.create_task(self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")))
async def load_platform(self, platform_config: dict):
'''实例化一个平台'''
if not platform_config['enable']:
return
# 动态导入
try:
match platform_config['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
except Exception as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}")
if platform_config['type'] not in platform_cls_map:
logger.error(f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。")
return
cls_type = platform_cls_map[platform_config['type']]
logger.debug(f"尝试实例化 {platform_config['type']}({platform_config['id']}) 平台适配器 ...")
inst = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config['id']] = inst
self.platform_insts.append(inst)
asyncio.create_task(self._task_wrapper(asyncio.create_task(inst.run(), name=platform_config['id'] + "_platform")))
async def _task_wrapper(self, task: asyncio.Task):
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def reload(self, platform_config: dict):
# 还未实现完成,不要调用此方法
if platform_config['id'] in self._inst_map:
# 正在运行
if getattr(self._inst_map[platform_config['id']], 'terminate', None):
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
await self._inst_map[platform_config['id']].terminate()
logger.info(f"{platform_config['id']} 平台适配器已终止。")
del self._inst_map[platform_config['id']]
self.platform_insts.remove(self._inst_map[platform_config['id']])
else:
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
# 再启动新的实例
await self.load_platform(platform_config)
else:
# 先将 _inst_map 中在 platform_config 中不存在的实例删除
config_ids = [platform['id'] for platform in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
if getattr(self._inst_map[key], 'terminate', None):
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
await self._inst_map[key].terminate()
logger.info(f"{key} 平台适配器已终止。")
del self._inst_map[key]
self.platform_insts.remove(self._inst_map[key])
else:
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
# 再启动新的实例
await self.load_platform(platform_config)
def get_insts(self): def get_insts(self):
return self.platform_insts return self.platform_insts

View File

@@ -20,6 +20,12 @@ class Platform(abc.ABC):
''' '''
raise NotImplementedError raise NotImplementedError
async def terminate(self):
'''
终止一个平台的运行实例。
'''
pass
@abc.abstractmethod @abc.abstractmethod
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
''' '''

View File

@@ -32,6 +32,8 @@ class AiocqhttpAdapter(Platform):
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
) )
self.stop = False
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value: match session.message_type.value:
@@ -230,11 +232,15 @@ class AiocqhttpAdapter(Platform):
return bot return bot
async def terminate(self):
self.stop = True
await asyncio.sleep(1)
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return self.metadata return self.metadata
async def shutdown_trigger_placeholder(self): async def shutdown_trigger_placeholder(self):
while not self._event_queue.closed: while not self._event_queue.closed and not self.stop:
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。") logger.info("aiocqhttp 适配器已关闭。")

View File

@@ -54,6 +54,8 @@ class SimpleGewechatClient():
self.multimedia_downloader = None self.multimedia_downloader = None
self.userrealnames = {} self.userrealnames = {}
self.stop = False
async def get_token_id(self): async def get_token_id(self):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -231,7 +233,7 @@ class SimpleGewechatClient():
) )
async def shutdown_trigger_placeholder(self): async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed: while not self.event_queue.closed and not self.stop:
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。") logger.info("gewechat 适配器已关闭。")

View File

@@ -47,6 +47,10 @@ class GewechatPlatformAdapter(Platform):
"基于 gewechat 的 Wechat 适配器", "基于 gewechat 的 Wechat 适配器",
) )
async def terminate(self):
self.client.stop = True
await asyncio.sleep(1)
@override @override
def run(self): def run(self):
self.client = SimpleGewechatClient( self.client = SimpleGewechatClient(

View File

@@ -1,11 +1,9 @@
import traceback import traceback
import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType from .entites import ProviderType
from typing import List from typing import List
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from collections import defaultdict
from .register import provider_cls_map, llm_tools from .register import provider_cls_map, llm_tools
from astrbot.core import logger, sp from astrbot.core import logger, sp
@@ -16,6 +14,14 @@ class ProviderManager():
self.provider_stt_settings: dict = config.get('provider_stt_settings', {}) self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {}) self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', []) self.persona_configs: list = config.get('persona', [])
self.astrbot_config = config
self.selected_provider_id = sp.get("curr_provider")
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
self.provider_enabled = self.provider_settings.get("enable", False)
self.stt_enabled = self.provider_stt_settings.get("enable", False)
self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理 # 人格情景管理
# 目前没有拆成独立的模块 # 目前没有拆成独立的模块
@@ -75,14 +81,15 @@ class ProviderManager():
_mood_imitation_dialogs_processed="" _mood_imitation_dialogs_processed=""
) )
self.personas.append(self.selected_default_persona) self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = [] self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例''' '''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = [] self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例''' '''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = [] self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例''' '''加载的 Text To Speech Provider 的实例'''
self.inst_map = {}
'''Provider 实例映射. key: provider_id, value: Provider 实例'''
self.llm_tools = llm_tools self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例''' '''当前使用的 Provider 实例'''
@@ -90,7 +97,6 @@ class ProviderManager():
'''当前使用的 Speech To Text Provider 实例''' '''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例''' '''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper self.db_helper = db_helper
# kdb(experimental) # kdb(experimental)
@@ -99,145 +105,155 @@ class ProviderManager():
if kdb_cfg and len(kdb_cfg): if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0] self.curr_kdb_name = list(kdb_cfg.keys())[0]
changed = False
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}")
provider_cfg['id'] = new_id
changed = True
self.loaded_ids[provider_cfg['id']] = True
try:
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "anthropic_chat_completion":
from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope":
from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
if changed:
try:
config.save_config()
except Exception as e:
logger.warning(f"保存配置文件失败:{e}")
async def initialize(self): async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
for provider_config in self.providers_config: for provider_config in self.providers_config:
if not provider_config['enable']: await self.load_provider(provider_config)
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
provider_metadata = provider_cls_map[provider_config['type']]
logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get('persistant_history', True),
self.selected_default_persona
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst: if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。") logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst: if self.stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。") logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst: if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。") logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
async def load_provider(self, provider_config: dict):
if not provider_config['enable']:
return
# 动态导入
try:
match provider_config['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "anthropic_chat_completion":
from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope":
from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
return
except Exception as e:
logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因")
return
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
return
provider_metadata = provider_cls_map[provider_config['type']]
logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if self.selected_stt_provider_id == provider_config['id'] and self.stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
if not self.curr_stt_provider_inst and self.stt_enabled:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if self.selected_tts_provider_id == provider_config['id'] and self.tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
if not self.curr_tts_provider_inst and self.tts_enabled:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get('persistant_history', True),
self.selected_default_persona
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if self.selected_provider_id == provider_config['id'] and self.provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
if not self.curr_provider_inst and self.provider_enabled:
self.curr_provider_inst = inst
self.inst_map[provider_config['id']] = inst
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config['id'])
if provider_config['enable']:
await self.load_provider(provider_config)
# 和配置文件保持同步
config_ids = [provider['id'] for provider in self.providers_config]
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
def get_insts(self): def get_insts(self):
return self.provider_insts return self.provider_insts
async def terminate_provider(self, provider_id: str):
if provider_id in self.inst_map:
if self.inst_map[provider_id] in self.provider_insts:
self.provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] in self.stt_provider_insts:
self.stt_provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] in self.tts_provider_insts:
self.tts_provider_insts.remove(self.inst_map[provider_id])
if getattr(self.inst_map[provider_id], 'terminate', None):
logger.info(f"正在尝试终止 {provider_id} 提供商适配器 ...")
await self.inst_map[provider_id].terminate()
logger.info(f"{provider_id} 提供商适配器已终止。")
del self.inst_map[provider_id]
async def terminate(self): async def terminate(self):
for provider_inst in self.provider_insts: for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"): if hasattr(provider_inst, "terminate"):

View File

@@ -1,4 +1,5 @@
import typing import typing
import traceback
from .route import Route, Response, RouteContext from .route import Route, Response, RouteContext
from quart import request from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
@@ -77,6 +78,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
else: else:
errors, post_config = validate_config(post_config, config.schema, is_core) errors, post_config = validate_config(post_config, config.schema, is_core)
except BaseException as e: except BaseException as e:
logger.error(traceback.format_exc())
logger.warning(f"验证配置时出现异常: {e}") logger.warning(f"验证配置时出现异常: {e}")
if errors: if errors:
raise ValueError(f"格式校验未通过: {errors}") raise ValueError(f"格式校验未通过: {errors}")
@@ -90,6 +92,14 @@ class ConfigRoute(Route):
'/config/get': ('GET', self.get_configs), '/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs), '/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs), '/config/plugin/update': ('POST', self.post_plugin_configs),
'/config/platform/new': ('POST', self.post_new_platform),
'/config/platform/update': ('POST', self.post_update_platform),
'/config/platform/delete': ('POST', self.post_delete_platform),
'/config/provider/new': ('POST', self.post_new_provider),
'/config/provider/update': ('POST', self.post_update_provider),
'/config/provider/delete': ('POST', self.post_delete_provider)
} }
self.register_routes() self.register_routes()
@@ -118,7 +128,99 @@ class ConfigRoute(Route):
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__ return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
except Exception as e: except Exception as e:
return Response().error(str(e)).__dict__ return Response().error(str(e)).__dict__
async def post_new_platform(self):
new_platform_config = await request.json
self.config['platform'].append(new_platform_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.load_platform(new_platform_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增平台配置成功~").__dict__
async def post_new_provider(self):
new_provider_config = await request.json
self.config['provider'].append(new_provider_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.load_provider(new_provider_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增服务提供商配置成功~").__dict__
async def post_update_platform(self):
update_platform_config = await request.json
platform_id = update_platform_config.get("id", None)
new_config = update_platform_config.get("config", None)
if not platform_id or not new_config:
return Response().error("参数错误").__dict__
for i, platform in enumerate(self.config['platform']):
if platform['id'] == platform_id:
self.config['platform'][i] = new_config
break
else:
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新平台配置成功~").__dict__
async def post_update_provider(self):
update_provider_config = await request.json
provider_id = update_provider_config.get("id", None)
new_config = update_provider_config.get("config", None)
if not provider_id or not new_config:
return Response().error("参数错误").__dict__
for i, provider in enumerate(self.config['provider']):
if provider['id'] == provider_id:
self.config['provider'][i] = new_config
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.reload(new_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新成功,已经实时生效~").__dict__
async def post_delete_platform(self):
platform_id = await request.json
platform_id = platform_id.get("id")
for i, platform in enumerate(self.config['platform']):
if platform['id'] == platform_id:
del self.config['platform'][i]
break
else:
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除平台配置成功~").__dict__
async def post_delete_provider(self):
provider_id = await request.json
provider_id = provider_id.get("id")
for i, provider in enumerate(self.config['provider']):
if provider['id'] == provider_id:
del self.config['provider'][i]
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.terminate_provider(provider_id)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def _get_astrbot_config(self): async def _get_astrbot_config(self):
config = self.config config = self.config

View File

@@ -4,17 +4,6 @@
<v-card-title>正在等待 AstrBot 重启...</v-card-title> <v-card-title>正在等待 AstrBot 重启...</v-card-title>
<v-card-text> <v-card-text>
<v-progress-linear indeterminate color="primary"></v-progress-linear> <v-progress-linear indeterminate color="primary"></v-progress-linear>
<div style="margin-top: 16px;">
<div class="py-12 text-center" v-if="newStartTime != -1">
<v-icon class="mb-6" color="success" icon="mdi-check-circle-outline" size="128"></v-icon>
<p>重启成功</p>
</div>
<small v-if="startTime != -1" style="display: block;">当前实例标识{{ startTime }}</small>
<small v-if="newStartTime != -1" style="display: block;">检查到新实例{{ newStartTime }}即将自动刷新页面</small>
<small v-if="status" style="display: block;">{{ status }}</small>
<small style="display: block;">尝试次数{{ cnt }} / 60</small>
</div>
</v-card-text> </v-card-text>
</v-card> </v-card>
</v-dialog> </v-dialog>
@@ -73,11 +62,9 @@ export default {
if (this.newStartTime !== this.startTime) { if (this.newStartTime !== this.startTime) {
this.newStartTime = newStartTime this.newStartTime = newStartTime
console.log('wfr: restarted') console.log('wfr: restarted')
setTimeout(() => { this.visible = false
this.visible = false // reload
// reload window.location.reload()
window.location.reload()
}, 2000)
} }
return this.newStartTime return this.newStartTime
} }

View File

@@ -21,7 +21,17 @@ const sidebarItem: menu[] = [
to: '/dashboard/default' to: '/dashboard/default'
}, },
{ {
title: '配置文件', title: '消息平台',
icon: 'mdi-message-processing',
to: '/platforms',
},
{
title: '服务提供商',
icon: 'mdi-creation',
to: '/providers',
},
{
title: '配置',
icon: 'mdi-cog', icon: 'mdi-cog',
to: '/config', to: '/config',
}, },

View File

@@ -16,12 +16,21 @@ const MainRoutes = {
path: '/extension', path: '/extension',
component: () => import('@/views/ExtensionPage.vue') component: () => import('@/views/ExtensionPage.vue')
}, },
{
name: 'Platforms',
path: '/platforms',
component: () => import('@/views/PlatformPage.vue')
},
{
name: 'Providers',
path: '/providers',
component: () => import('@/views/ProviderPage.vue')
},
{ {
name: 'Configs', name: 'Configs',
path: '/config', path: '/config',
component: () => import('@/views/ConfigPage.vue') component: () => import('@/views/ConfigPage.vue')
}, },
{ {
name: 'Default', name: 'Default',
path: '/dashboard/default', path: '/dashboard/default',

View File

@@ -44,6 +44,11 @@ import config from '@/config';
</v-expansion-panel-title> </v-expansion-panel-title>
<v-expansion-panel-text v-if="metadata[key]['metadata'][key2]?.config_template"> <v-expansion-panel-text v-if="metadata[key]['metadata'][key2]?.config_template">
<!-- 带有 config_template 的配置项 --> <!-- 带有 config_template 的配置项 -->
<v-alert style="margin-top: 16px; margin-bottom: 16px" color="primary" variant="tonal">
消息平台适配器和服务提供商的配置已经迁移至更方便的独立页面推荐前往左栏配置哦
</v-alert>
<v-tabs style="margin-top: 16px;" align-tabs="left" color="deep-purple-accent-4" v-model="config_template_tab"> <v-tabs style="margin-top: 16px;" align-tabs="left" color="deep-purple-accent-4" v-model="config_template_tab">
<v-tab v-if="metadata[key]['metadata'][key2]?.tmpl_display_title" v-for="(item, index) in config_data[key2]" :key="index" :value="index"> <v-tab v-if="metadata[key]['metadata'][key2]?.tmpl_display_title" v-for="(item, index) in config_data[key2]" :key="index" :value="index">
{{ item[metadata[key]['metadata'][key2]?.tmpl_display_title] }} {{ item[metadata[key]['metadata'][key2]?.tmpl_display_title] }}

View File

@@ -0,0 +1,225 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 32px; height: 100%;">
<v-menu>
<template v-slot:activator="{ props }">
<v-btn class="flex-grow-1" variant="tonal" @click="new_platform_dialog = true" size="large"
rounded="lg" v-bind="props" color="primary">
<template v-slot:default>
<v-icon>mdi-plus</v-icon>
新增平台适配器
</template>
</v-btn>
</template>
<v-list @update:selected="addFromDefaultConfigTmpl($event)">
<v-list-item
v-for="(item, index) in metadata['platform_group']['metadata']['platform'].config_template"
:key="index" rounded="xl" :value="index">
<v-list-item-title>{{ index }}</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
<v-row style="margin-top: 16px;">
<v-col v-for="(platform, index) in config_data['platform']" :key="index" cols="12" md="6" lg="3">
<v-card class="fade-in" style="margin-bottom: 16px; min-height: 200px; display: flex; justify-content: space-between; flex-direction: column;">
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h4">{{ platform.id }}</span>
<v-switch color="primary" hide-details density="compact" v-model="platform['enable']"
@update:modelValue="platformStatusChange(platform)"></v-switch>
</v-card-title>
<v-card-text>
<div>
<span style="font-size:12px">适配器类型: </span>
<v-chip color="primary" text>{{ platform.type }}</v-chip>
</div>
</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn color="error" text @click="deletePlatform(platform.id);">
删除
</v-btn>
<v-btn color="blue-darken-1" text
@click="updatingMode = true; showPlatformCfg = true; newSelectedPlatformConfig = platform; newSelectedPlatformName = platform.id">
配置
</v-btn>
</v-card-actions>
</v-card>
</v-col>
</v-row>
<v-dialog v-model="showPlatformCfg" width="700">
<v-card>
<v-card-title>
<span class="text-h4">{{ newSelectedPlatformName }} 配置</span>
</v-card-title>
<v-card-text>
<AstrBotConfig :iterable="newSelectedPlatformConfig"
:metadata="metadata['platform_group']['metadata']" metadataKey="platform" />
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="newPlatform" :loading="loading">
保存
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</v-card-text>
</v-card>
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack">
{{ save_message }}
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script>
import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
export default {
name: 'PlatformPage',
components: {
AstrBotConfig,
WaitingForRestart
},
data() {
return {
config_data: {},
fetched: false,
metadata: {},
showPlatformCfg: false,
newSelectedPlatformName: '',
newSelectedPlatformConfig: {},
updatingMode: false,
loading: false,
save_message_snack: false,
save_message: "",
save_message_success: "",
}
},
mounted() {
this.getConfig();
},
methods: {
getConfig() {
// 获取配置
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
this.fetched = true
this.metadata = res.data.data.metadata;
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
addFromDefaultConfigTmpl(index) {
// 从默认配置模板中添加
console.log(index);
this.newSelectedPlatformName = index[0];
this.showPlatformCfg = true;
this.updatingMode = false;
this.newSelectedPlatformConfig = this.metadata['platform_group']['metadata']['platform'].config_template[index[0]];
},
newPlatform() {
// 新建或者更新平台
this.loading = true;
if (this.updatingMode) {
axios.post('/api/config/platform/update', {
id: this.newSelectedPlatformName,
config: this.newSelectedPlatformConfig
}).then((res) => {
this.loading = false;
this.showPlatformCfg = false;
this.getConfig();
this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
this.updatingMode = false;
} else {
axios.post('/api/config/platform/new', this.newSelectedPlatformConfig).then((res) => {
this.loading = false;
this.showPlatformCfg = false;
this.getConfig();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
},
deletePlatform(platform_id) {
// 删除平台
axios.post('/api/config/platform/delete', { id: platform_id }).then((res) => {
this.getConfig();
this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
platformStatusChange(platform) {
// 平台状态改变
axios.post('/api/config/platform/update', {
id: platform.id,
config: platform
}).then((res) => {
this.getConfig();
this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
</style>

View File

@@ -0,0 +1,221 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 32px; height: 100%;">
<v-menu>
<template v-slot:activator="{ props }">
<v-btn class="flex-grow-1" variant="tonal" @click="new_provider_dialog = true" size="large"
rounded="lg" v-bind="props" color="primary">
<template v-slot:default>
<v-icon>mdi-plus</v-icon>
新增服务提供商
</template>
</v-btn>
</template>
<v-list @update:selected="addFromDefaultConfigTmpl($event)">
<v-list-item
v-for="(item, index) in metadata['provider_group']['metadata']['provider'].config_template"
:key="index" rounded="xl" :value="index">
<v-list-item-title>{{ index }}</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
<v-row style="margin-top: 16px;">
<v-col v-for="(provider, index) in config_data['provider']" :key="index" cols="12" md="6" lg="3">
<v-card class="fade-in" style="margin-bottom: 16px; min-height: 200px; display: flex; justify-content: space-between; flex-direction: column;">
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h4">{{ provider.id }}</span>
<v-switch color="primary" hide-details density="compact" v-model="provider['enable']"
@update:modelValue="providerStatusChange(provider)"></v-switch>
</v-card-title>
<v-card-text>
<div>
<span style="font-size:12px">适配器类型: </span> <v-chip color="primary" text>{{ provider.type }}</v-chip>
</div>
</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn color="error" text @click="deleteprovider(provider.id);">
删除
</v-btn>
<v-btn color="blue-darken-1" text
@click="updatingMode = true; showproviderCfg = true; newSelectedproviderConfig = provider; newSelectedproviderName = provider.id">
配置
</v-btn>
</v-card-actions>
</v-card>
</v-col>
</v-row>
<v-dialog v-model="showproviderCfg" width="700">
<v-card>
<v-card-title>
<span class="text-h4">{{ newSelectedproviderName }} 配置</span>
</v-card-title>
<v-card-text>
<AstrBotConfig :iterable="newSelectedproviderConfig"
:metadata="metadata['provider_group']['metadata']" metadataKey="provider" />
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="newprovider" :loading="loading">
保存
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</v-card-text>
</v-card>
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack">
{{ save_message }}
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script>
import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
export default {
name: 'ProviderPage',
components: {
AstrBotConfig,
WaitingForRestart
},
data() {
return {
config_data: {},
fetched: false,
metadata: {},
showproviderCfg: false,
newSelectedproviderName: '',
newSelectedproviderConfig: {},
updatingMode: false,
loading: false,
save_message_snack: false,
save_message: "",
save_message_success: "",
}
},
mounted() {
this.getConfig();
},
methods: {
getConfig() {
// 获取配置
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
this.fetched = true
this.metadata = res.data.data.metadata;
}).catch((err) => {
save_message = err;
save_message_snack = true;
save_message_success = "error";
});
},
addFromDefaultConfigTmpl(index) {
// 从默认配置模板中添加
console.log(index);
this.newSelectedproviderName = index[0];
this.showproviderCfg = true;
this.updatingMode = false;
this.newSelectedproviderConfig = this.metadata['provider_group']['metadata']['provider'].config_template[index[0]];
},
newprovider() {
// 新建或者更新平台
this.loading = true;
if (this.updatingMode) {
axios.post('/api/config/provider/update', {
id: this.newSelectedproviderName,
config: this.newSelectedproviderConfig
}).then((res) => {
this.loading = false;
this.showproviderCfg = false;
this.getConfig();
// this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
this.updatingMode = false;
} else {
axios.post('/api/config/provider/new', this.newSelectedproviderConfig).then((res) => {
this.loading = false;
this.showproviderCfg = false;
this.getConfig();
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
},
deleteprovider(provider_id) {
// 删除平台
axios.post('/api/config/provider/delete', { id: provider_id }).then((res) => {
this.getConfig();
// this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
providerStatusChange(provider) {
// 平台状态改变
axios.post('/api/config/provider/update', {
id: provider.id,
config: provider
}).then((res) => {
this.getConfig();
// this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
</style>