From dcda871fc00d4e7ef08b6c29d1dcb92c18a1bf7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <62183434+zouyonghe@users.noreply.github.com> Date: Mon, 1 Dec 2025 01:06:10 +0900 Subject: [PATCH] feat: provider availability reachability improvements (#3708) --- .gitignore | 2 + astrbot/core/config/default.py | 6 + .../en-US/features/config-metadata.json | 6 +- .../zh-CN/features/config-metadata.json | 4 + packages/astrbot/commands/provider.py | 278 ++++++++++++++++-- 5 files changed, 276 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 0b3686f8..9472296b 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,5 @@ astrbot.lock chroma venv/* pytest.ini +AGENTS.md +IFLOW.md diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index f3a4a7d2..7e71626c 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -73,6 +73,7 @@ DEFAULT_CONFIG = { "coze_agent_runner_provider_id": "", "dashscope_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", + "reachability_check": True, "max_agent_step": 30, "tool_call_timeout": 60, }, @@ -2279,6 +2280,11 @@ CONFIG_METADATA_3 = { "_special": "select_provider", "hint": "留空代表不使用,可用于非多模态模型", }, + "provider_settings.reachability_check": { + "description": "提供商可达性检测", + "type": "bool", + "hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。", + }, "provider_stt_settings.enable": { "description": "启用语音转文本", "type": "bool", diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index a4a72f61..2ae52074 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -36,6 +36,10 @@ "description": "Default Image Caption Model", "hint": "Leave empty to disable; useful for non-multimodal models" }, + "reachability_check": { + "description": "Provider Reachability Check", + "hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens." + }, "image_caption_prompt": { "description": "Image Caption Prompt" } @@ -453,4 +457,4 @@ } } } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 0aee49df..0c1046c0 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -36,6 +36,10 @@ "description": "默认图片转述模型", "hint": "留空代表不使用,可用于非多模态模型" }, + "reachability_check": { + "description": "提供商可达性检测", + "hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。" + }, "image_caption_prompt": { "description": "图片转述提示词" } diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py index 8db7324e..d306c41a 100644 --- a/packages/astrbot/commands/provider.py +++ b/packages/astrbot/commands/provider.py @@ -1,14 +1,169 @@ +import asyncio +import os import re +from astrbot import logger from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.provider import RerankProvider +from astrbot.core.utils.astrbot_path import get_astrbot_path + +REACHABILITY_CHECK_TIMEOUT = 30.0 class ProviderCommands: def __init__(self, context: star.Context): self.context = context + def _log_reachability_failure( + self, + provider, + provider_capability_type: ProviderType | None, + err_code: str, + err_reason: str, + ): + """记录不可达原因到日志。""" + meta = provider.meta() + logger.warning( + "Provider reachability check failed: id=%s type=%s code=%s reason=%s", + meta.id, + provider_capability_type.name if provider_capability_type else "unknown", + err_code, + err_reason, + ) + + async def _test_provider_capability(self, provider): + """测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)""" + meta = provider.meta() + provider_capability_type = meta.provider_type + + try: + if provider_capability_type == ProviderType.CHAT_COMPLETION: + # 发送 "Ping" 测试对话 + response = await asyncio.wait_for( + provider.text_chat(prompt="REPLY `PONG` ONLY"), + timeout=REACHABILITY_CHECK_TIMEOUT, + ) + if response is not None: + return True, None, None + err_code = "EMPTY_RESPONSE" + err_reason = "Provider returned empty response" + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason + + elif provider_capability_type == ProviderType.EMBEDDING: + # 测试 Embedding + embedding_result = await asyncio.wait_for( + provider.get_embedding("health_check"), + timeout=REACHABILITY_CHECK_TIMEOUT, + ) + if ( + isinstance(embedding_result, list) + and embedding_result + and all(isinstance(x, (int, float)) for x in embedding_result) + ): + return True, None, None + err_code = "INVALID_EMBEDDING" + err_reason = "Provider returned invalid embedding" + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason + + elif provider_capability_type == ProviderType.TEXT_TO_SPEECH: + # 测试 TTS + audio_result = await asyncio.wait_for( + provider.get_audio("你好"), + timeout=REACHABILITY_CHECK_TIMEOUT, + ) + if isinstance(audio_result, str) and audio_result: + # 清理检测生成的临时音频文件,避免频繁检测时堆积 + if os.path.isfile(audio_result): + try: + os.remove(audio_result) + except OSError as e: + logger.debug( + "Failed to cleanup TTS health check file %s: %s", + audio_result, + e, + ) + return True, None, None + err_code = "INVALID_AUDIO" + err_reason = "Provider returned invalid audio" + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason + + elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: + # 测试 STT + sample_audio_path = os.path.join( + get_astrbot_path(), + "samples", + "stt_health_check.wav", + ) + if not os.path.exists(sample_audio_path): + # 如果样本文件不存在,降级为检查是否实现了方法 + return hasattr(provider, "get_text"), None, None + + text_result = await asyncio.wait_for( + provider.get_text(sample_audio_path), + timeout=REACHABILITY_CHECK_TIMEOUT, + ) + if isinstance(text_result, str) and text_result: + return True, None, None + err_code = "INVALID_TEXT" + err_reason = "Provider returned invalid text" + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason + + elif provider_capability_type == ProviderType.RERANK: + # 测试 Rerank + if isinstance(provider, RerankProvider): + await asyncio.wait_for( + provider.rerank("Apple", documents=["apple", "banana"]), + timeout=REACHABILITY_CHECK_TIMEOUT, + ) + return True, None, None + err_code = "NOT_RERANK_PROVIDER" + err_reason = "Provider is not RerankProvider" + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason + + else: + # 其他类型暂时视为通过,或者回退到 get_models + if hasattr(provider, "get_models"): + await asyncio.wait_for( + provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT + ) + return True, None, None + return True, None, None # 未知类型默认通过 + + except asyncio.TimeoutError: + err_code = "TIMEOUT" + err_reason = "Reachability check timed out" + except Exception as exc: + err_code = ( + getattr(exc, "status_code", None) + or getattr(exc, "code", None) + or getattr(exc, "error_code", None) + ) + err_reason = str(exc) + if not err_code: + err_code = exc.__class__.__name__ + + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason + async def provider( self, event: AstrMessageEvent, @@ -17,46 +172,131 @@ class ProviderCommands: ): """查看或者切换 LLM Provider""" umo = event.unified_msg_origin + cfg = self.context.get_config(umo).get("provider_settings", {}) + reachability_check_enabled = cfg.get("reachability_check", True) if idx is None: parts = ["## 载入的 LLM 提供商\n"] - for idx, llm in enumerate(self.context.get_all_providers()): - id_ = llm.meta().id - line = f"{idx + 1}. {id_} ({llm.meta().model})" + + # 获取所有类型的提供商 + llms = list(self.context.get_all_providers()) + ttss = self.context.get_all_tts_providers() + stts = self.context.get_all_stt_providers() + + # 构造待检测列表: [(provider, type_label), ...] + all_providers = [] + all_providers.extend([(p, "llm") for p in llms]) + all_providers.extend([(p, "tts") for p in ttss]) + all_providers.extend([(p, "stt") for p in stts]) + + # 并发测试连通性 + if reachability_check_enabled: + if all_providers: + await event.send( + MessageEventResult().message( + "正在进行提供商可达性测试,请稍候..." + ) + ) + check_results = await asyncio.gather( + *[self._test_provider_capability(p) for p, _ in all_providers], + return_exceptions=True, + ) + else: + # 用 None 表示未检测 + check_results = [None for _ in all_providers] + + # 整合结果 + display_data = [] + for (p, p_type), reachable in zip(all_providers, check_results): + meta = p.meta() + id_ = meta.id + error_code = None + + if isinstance(reachable, Exception): + # 异常情况下兜底处理,避免单个 provider 导致列表失败 + self._log_reachability_failure( + p, + None, + reachable.__class__.__name__, + str(reachable), + ) + reachable_flag = False + error_code = reachable.__class__.__name__ + elif isinstance(reachable, tuple): + reachable_flag, error_code, _ = reachable + else: + reachable_flag = reachable + + # 根据类型构建显示名称 + if p_type == "llm": + info = f"{id_} ({meta.model})" + else: + info = f"{id_}" + + # 确定状态标记 + if reachable_flag is True: + mark = " ✅" + elif reachable_flag is False: + if error_code: + mark = f" ❌(错误码: {error_code})" + else: + mark = " ❌" + else: + mark = "" # 不支持检测时不显示标记 + + display_data.append( + { + "type": p_type, + "info": info, + "mark": mark, + "provider": p, + } + ) + + # 分组输出 + # 1. LLM + llm_data = [d for d in display_data if d["type"] == "llm"] + for i, d in enumerate(llm_data): + line = f"{i + 1}. {d['info']}{d['mark']}" provider_using = self.context.get_using_provider(umo=umo) - if provider_using and provider_using.meta().id == id_: + if ( + provider_using + and provider_using.meta().id == d["provider"].meta().id + ): line += " (当前使用)" parts.append(line + "\n") - tts_providers = self.context.get_all_tts_providers() - if tts_providers: + # 2. TTS + tts_data = [d for d in display_data if d["type"] == "tts"] + if tts_data: parts.append("\n## 载入的 TTS 提供商\n") - for idx, tts in enumerate(tts_providers): - id_ = tts.meta().id - line = f"{idx + 1}. {id_}" + for i, d in enumerate(tts_data): + line = f"{i + 1}. {d['info']}{d['mark']}" tts_using = self.context.get_using_tts_provider(umo=umo) - if tts_using and tts_using.meta().id == id_: + if tts_using and tts_using.meta().id == d["provider"].meta().id: line += " (当前使用)" parts.append(line + "\n") - stt_providers = self.context.get_all_stt_providers() - if stt_providers: + # 3. STT + stt_data = [d for d in display_data if d["type"] == "stt"] + if stt_data: parts.append("\n## 载入的 STT 提供商\n") - for idx, stt in enumerate(stt_providers): - id_ = stt.meta().id - line = f"{idx + 1}. {id_}" + for i, d in enumerate(stt_data): + line = f"{i + 1}. {d['info']}{d['mark']}" stt_using = self.context.get_using_stt_provider(umo=umo) - if stt_using and stt_using.meta().id == id_: + if stt_using and stt_using.meta().id == d["provider"].meta().id: line += " (当前使用)" parts.append(line + "\n") parts.append("\n使用 /provider <序号> 切换 LLM 提供商。") ret = "".join(parts) - if tts_providers: + if ttss: ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" - if stt_providers: - ret += "\n使用 /provider stt <切换> STT 提供商。" + if stts: + ret += "\n使用 /provider stt <序号> 切换 STT 提供商。" + if not reachability_check_enabled: + ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。" event.set_result(MessageEventResult().message(ret)) elif idx == "tts":