feat: provider availability reachability improvements (#3708)

This commit is contained in:
邹永赫
2025-12-01 01:06:10 +09:00
committed by GitHub
parent c13c51f499
commit dcda871fc0
5 changed files with 276 additions and 20 deletions

2
.gitignore vendored
View File

@@ -48,3 +48,5 @@ astrbot.lock
chroma
venv/*
pytest.ini
AGENTS.md
IFLOW.md

View File

@@ -73,6 +73,7 @@ DEFAULT_CONFIG = {
"coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": True,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
@@ -2279,6 +2280,11 @@ CONFIG_METADATA_3 = {
"_special": "select_provider",
"hint": "留空代表不使用,可用于非多模态模型",
},
"provider_settings.reachability_check": {
"description": "提供商可达性检测",
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",

View File

@@ -36,6 +36,10 @@
"description": "Default Image Caption Model",
"hint": "Leave empty to disable; useful for non-multimodal models"
},
"reachability_check": {
"description": "Provider Reachability Check",
"hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
},
"image_caption_prompt": {
"description": "Image Caption Prompt"
}
@@ -453,4 +457,4 @@
}
}
}
}
}

View File

@@ -36,6 +36,10 @@
"description": "默认图片转述模型",
"hint": "留空代表不使用,可用于非多模态模型"
},
"reachability_check": {
"description": "提供商可达性检测",
"hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
},
"image_caption_prompt": {
"description": "图片转述提示词"
}

View File

@@ -1,14 +1,169 @@
import asyncio
import os
import re
from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_path
REACHABILITY_CHECK_TIMEOUT = 30.0
class ProviderCommands:
def __init__(self, context: star.Context):
self.context = context
def _log_reachability_failure(
self,
provider,
provider_capability_type: ProviderType | None,
err_code: str,
err_reason: str,
):
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
meta.id,
provider_capability_type.name if provider_capability_type else "unknown",
err_code,
err_reason,
)
async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)"""
meta = provider.meta()
provider_capability_type = meta.provider_type
try:
if provider_capability_type == ProviderType.CHAT_COMPLETION:
# 发送 "Ping" 测试对话
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if response is not None:
return True, None, None
err_code = "EMPTY_RESPONSE"
err_reason = "Provider returned empty response"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.EMBEDDING:
# 测试 Embedding
embedding_result = await asyncio.wait_for(
provider.get_embedding("health_check"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if (
isinstance(embedding_result, list)
and embedding_result
and all(isinstance(x, (int, float)) for x in embedding_result)
):
return True, None, None
err_code = "INVALID_EMBEDDING"
err_reason = "Provider returned invalid embedding"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
# 测试 TTS
audio_result = await asyncio.wait_for(
provider.get_audio("你好"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if isinstance(audio_result, str) and audio_result:
# 清理检测生成的临时音频文件,避免频繁检测时堆积
if os.path.isfile(audio_result):
try:
os.remove(audio_result)
except OSError as e:
logger.debug(
"Failed to cleanup TTS health check file %s: %s",
audio_result,
e,
)
return True, None, None
err_code = "INVALID_AUDIO"
err_reason = "Provider returned invalid audio"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
# 测试 STT
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
# 如果样本文件不存在,降级为检查是否实现了方法
return hasattr(provider, "get_text"), None, None
text_result = await asyncio.wait_for(
provider.get_text(sample_audio_path),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if isinstance(text_result, str) and text_result:
return True, None, None
err_code = "INVALID_TEXT"
err_reason = "Provider returned invalid text"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.RERANK:
# 测试 Rerank
if isinstance(provider, RerankProvider):
await asyncio.wait_for(
provider.rerank("Apple", documents=["apple", "banana"]),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
return True, None, None
err_code = "NOT_RERANK_PROVIDER"
err_reason = "Provider is not RerankProvider"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
else:
# 其他类型暂时视为通过,或者回退到 get_models
if hasattr(provider, "get_models"):
await asyncio.wait_for(
provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT
)
return True, None, None
return True, None, None # 未知类型默认通过
except asyncio.TimeoutError:
err_code = "TIMEOUT"
err_reason = "Reachability check timed out"
except Exception as exc:
err_code = (
getattr(exc, "status_code", None)
or getattr(exc, "code", None)
or getattr(exc, "error_code", None)
)
err_reason = str(exc)
if not err_code:
err_code = exc.__class__.__name__
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def provider(
self,
event: AstrMessageEvent,
@@ -17,46 +172,131 @@ class ProviderCommands:
):
"""查看或者切换 LLM Provider"""
umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
reachability_check_enabled = cfg.get("reachability_check", True)
if idx is None:
parts = ["## 载入的 LLM 提供商\n"]
for idx, llm in enumerate(self.context.get_all_providers()):
id_ = llm.meta().id
line = f"{idx + 1}. {id_} ({llm.meta().model})"
# 获取所有类型的提供商
llms = list(self.context.get_all_providers())
ttss = self.context.get_all_tts_providers()
stts = self.context.get_all_stt_providers()
# 构造待检测列表: [(provider, type_label), ...]
all_providers = []
all_providers.extend([(p, "llm") for p in llms])
all_providers.extend([(p, "tts") for p in ttss])
all_providers.extend([(p, "stt") for p in stts])
# 并发测试连通性
if reachability_check_enabled:
if all_providers:
await event.send(
MessageEventResult().message(
"正在进行提供商可达性测试,请稍候..."
)
)
check_results = await asyncio.gather(
*[self._test_provider_capability(p) for p, _ in all_providers],
return_exceptions=True,
)
else:
# 用 None 表示未检测
check_results = [None for _ in all_providers]
# 整合结果
display_data = []
for (p, p_type), reachable in zip(all_providers, check_results):
meta = p.meta()
id_ = meta.id
error_code = None
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
str(reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
elif isinstance(reachable, tuple):
reachable_flag, error_code, _ = reachable
else:
reachable_flag = reachable
# 根据类型构建显示名称
if p_type == "llm":
info = f"{id_} ({meta.model})"
else:
info = f"{id_}"
# 确定状态标记
if reachable_flag is True:
mark = ""
elif reachable_flag is False:
if error_code:
mark = f" ❌(错误码: {error_code})"
else:
mark = ""
else:
mark = "" # 不支持检测时不显示标记
display_data.append(
{
"type": p_type,
"info": info,
"mark": mark,
"provider": p,
}
)
# 分组输出
# 1. LLM
llm_data = [d for d in display_data if d["type"] == "llm"]
for i, d in enumerate(llm_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
provider_using = self.context.get_using_provider(umo=umo)
if provider_using and provider_using.meta().id == id_:
if (
provider_using
and provider_using.meta().id == d["provider"].meta().id
):
line += " (当前使用)"
parts.append(line + "\n")
tts_providers = self.context.get_all_tts_providers()
if tts_providers:
# 2. TTS
tts_data = [d for d in display_data if d["type"] == "tts"]
if tts_data:
parts.append("\n## 载入的 TTS 提供商\n")
for idx, tts in enumerate(tts_providers):
id_ = tts.meta().id
line = f"{idx + 1}. {id_}"
for i, d in enumerate(tts_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
tts_using = self.context.get_using_tts_provider(umo=umo)
if tts_using and tts_using.meta().id == id_:
if tts_using and tts_using.meta().id == d["provider"].meta().id:
line += " (当前使用)"
parts.append(line + "\n")
stt_providers = self.context.get_all_stt_providers()
if stt_providers:
# 3. STT
stt_data = [d for d in display_data if d["type"] == "stt"]
if stt_data:
parts.append("\n## 载入的 STT 提供商\n")
for idx, stt in enumerate(stt_providers):
id_ = stt.meta().id
line = f"{idx + 1}. {id_}"
for i, d in enumerate(stt_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
stt_using = self.context.get_using_stt_provider(umo=umo)
if stt_using and stt_using.meta().id == id_:
if stt_using and stt_using.meta().id == d["provider"].meta().id:
line += " (当前使用)"
parts.append(line + "\n")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
ret = "".join(parts)
if tts_providers:
if ttss:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
if stt_providers:
ret += "\n使用 /provider stt <切换> STT 提供商。"
if stts:
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
if not reachability_check_enabled:
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
event.set_result(MessageEventResult().message(ret))
elif idx == "tts":