feat: provider availability reachability improvements (#3708)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -48,3 +48,5 @@ astrbot.lock
|
|||||||
chroma
|
chroma
|
||||||
venv/*
|
venv/*
|
||||||
pytest.ini
|
pytest.ini
|
||||||
|
AGENTS.md
|
||||||
|
IFLOW.md
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ DEFAULT_CONFIG = {
|
|||||||
"coze_agent_runner_provider_id": "",
|
"coze_agent_runner_provider_id": "",
|
||||||
"dashscope_agent_runner_provider_id": "",
|
"dashscope_agent_runner_provider_id": "",
|
||||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||||
|
"reachability_check": True,
|
||||||
"max_agent_step": 30,
|
"max_agent_step": 30,
|
||||||
"tool_call_timeout": 60,
|
"tool_call_timeout": 60,
|
||||||
},
|
},
|
||||||
@@ -2279,6 +2280,11 @@ CONFIG_METADATA_3 = {
|
|||||||
"_special": "select_provider",
|
"_special": "select_provider",
|
||||||
"hint": "留空代表不使用,可用于非多模态模型",
|
"hint": "留空代表不使用,可用于非多模态模型",
|
||||||
},
|
},
|
||||||
|
"provider_settings.reachability_check": {
|
||||||
|
"description": "提供商可达性检测",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
|
||||||
|
},
|
||||||
"provider_stt_settings.enable": {
|
"provider_stt_settings.enable": {
|
||||||
"description": "启用语音转文本",
|
"description": "启用语音转文本",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
|
|||||||
@@ -36,6 +36,10 @@
|
|||||||
"description": "Default Image Caption Model",
|
"description": "Default Image Caption Model",
|
||||||
"hint": "Leave empty to disable; useful for non-multimodal models"
|
"hint": "Leave empty to disable; useful for non-multimodal models"
|
||||||
},
|
},
|
||||||
|
"reachability_check": {
|
||||||
|
"description": "Provider Reachability Check",
|
||||||
|
"hint": "When running the /provider command, test provider connectivity in parallel. This actively pings models and may consume extra tokens."
|
||||||
|
},
|
||||||
"image_caption_prompt": {
|
"image_caption_prompt": {
|
||||||
"description": "Image Caption Prompt"
|
"description": "Image Caption Prompt"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,10 @@
|
|||||||
"description": "默认图片转述模型",
|
"description": "默认图片转述模型",
|
||||||
"hint": "留空代表不使用,可用于非多模态模型"
|
"hint": "留空代表不使用,可用于非多模态模型"
|
||||||
},
|
},
|
||||||
|
"reachability_check": {
|
||||||
|
"description": "提供商可达性检测",
|
||||||
|
"hint": "/provider 命令列出模型时并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。"
|
||||||
|
},
|
||||||
"image_caption_prompt": {
|
"image_caption_prompt": {
|
||||||
"description": "图片转述提示词"
|
"description": "图片转述提示词"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,169 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
from astrbot.api import star
|
from astrbot.api import star
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
|
from astrbot.core.provider.provider import RerankProvider
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||||
|
|
||||||
|
REACHABILITY_CHECK_TIMEOUT = 30.0
|
||||||
|
|
||||||
|
|
||||||
class ProviderCommands:
|
class ProviderCommands:
|
||||||
def __init__(self, context: star.Context):
|
def __init__(self, context: star.Context):
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
|
def _log_reachability_failure(
|
||||||
|
self,
|
||||||
|
provider,
|
||||||
|
provider_capability_type: ProviderType | None,
|
||||||
|
err_code: str,
|
||||||
|
err_reason: str,
|
||||||
|
):
|
||||||
|
"""记录不可达原因到日志。"""
|
||||||
|
meta = provider.meta()
|
||||||
|
logger.warning(
|
||||||
|
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
|
||||||
|
meta.id,
|
||||||
|
provider_capability_type.name if provider_capability_type else "unknown",
|
||||||
|
err_code,
|
||||||
|
err_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _test_provider_capability(self, provider):
|
||||||
|
"""测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)"""
|
||||||
|
meta = provider.meta()
|
||||||
|
provider_capability_type = meta.provider_type
|
||||||
|
|
||||||
|
try:
|
||||||
|
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
||||||
|
# 发送 "Ping" 测试对话
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
provider.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||||
|
timeout=REACHABILITY_CHECK_TIMEOUT,
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
return True, None, None
|
||||||
|
err_code = "EMPTY_RESPONSE"
|
||||||
|
err_reason = "Provider returned empty response"
|
||||||
|
self._log_reachability_failure(
|
||||||
|
provider, provider_capability_type, err_code, err_reason
|
||||||
|
)
|
||||||
|
return False, err_code, err_reason
|
||||||
|
|
||||||
|
elif provider_capability_type == ProviderType.EMBEDDING:
|
||||||
|
# 测试 Embedding
|
||||||
|
embedding_result = await asyncio.wait_for(
|
||||||
|
provider.get_embedding("health_check"),
|
||||||
|
timeout=REACHABILITY_CHECK_TIMEOUT,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
isinstance(embedding_result, list)
|
||||||
|
and embedding_result
|
||||||
|
and all(isinstance(x, (int, float)) for x in embedding_result)
|
||||||
|
):
|
||||||
|
return True, None, None
|
||||||
|
err_code = "INVALID_EMBEDDING"
|
||||||
|
err_reason = "Provider returned invalid embedding"
|
||||||
|
self._log_reachability_failure(
|
||||||
|
provider, provider_capability_type, err_code, err_reason
|
||||||
|
)
|
||||||
|
return False, err_code, err_reason
|
||||||
|
|
||||||
|
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
||||||
|
# 测试 TTS
|
||||||
|
audio_result = await asyncio.wait_for(
|
||||||
|
provider.get_audio("你好"),
|
||||||
|
timeout=REACHABILITY_CHECK_TIMEOUT,
|
||||||
|
)
|
||||||
|
if isinstance(audio_result, str) and audio_result:
|
||||||
|
# 清理检测生成的临时音频文件,避免频繁检测时堆积
|
||||||
|
if os.path.isfile(audio_result):
|
||||||
|
try:
|
||||||
|
os.remove(audio_result)
|
||||||
|
except OSError as e:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to cleanup TTS health check file %s: %s",
|
||||||
|
audio_result,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return True, None, None
|
||||||
|
err_code = "INVALID_AUDIO"
|
||||||
|
err_reason = "Provider returned invalid audio"
|
||||||
|
self._log_reachability_failure(
|
||||||
|
provider, provider_capability_type, err_code, err_reason
|
||||||
|
)
|
||||||
|
return False, err_code, err_reason
|
||||||
|
|
||||||
|
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
||||||
|
# 测试 STT
|
||||||
|
sample_audio_path = os.path.join(
|
||||||
|
get_astrbot_path(),
|
||||||
|
"samples",
|
||||||
|
"stt_health_check.wav",
|
||||||
|
)
|
||||||
|
if not os.path.exists(sample_audio_path):
|
||||||
|
# 如果样本文件不存在,降级为检查是否实现了方法
|
||||||
|
return hasattr(provider, "get_text"), None, None
|
||||||
|
|
||||||
|
text_result = await asyncio.wait_for(
|
||||||
|
provider.get_text(sample_audio_path),
|
||||||
|
timeout=REACHABILITY_CHECK_TIMEOUT,
|
||||||
|
)
|
||||||
|
if isinstance(text_result, str) and text_result:
|
||||||
|
return True, None, None
|
||||||
|
err_code = "INVALID_TEXT"
|
||||||
|
err_reason = "Provider returned invalid text"
|
||||||
|
self._log_reachability_failure(
|
||||||
|
provider, provider_capability_type, err_code, err_reason
|
||||||
|
)
|
||||||
|
return False, err_code, err_reason
|
||||||
|
|
||||||
|
elif provider_capability_type == ProviderType.RERANK:
|
||||||
|
# 测试 Rerank
|
||||||
|
if isinstance(provider, RerankProvider):
|
||||||
|
await asyncio.wait_for(
|
||||||
|
provider.rerank("Apple", documents=["apple", "banana"]),
|
||||||
|
timeout=REACHABILITY_CHECK_TIMEOUT,
|
||||||
|
)
|
||||||
|
return True, None, None
|
||||||
|
err_code = "NOT_RERANK_PROVIDER"
|
||||||
|
err_reason = "Provider is not RerankProvider"
|
||||||
|
self._log_reachability_failure(
|
||||||
|
provider, provider_capability_type, err_code, err_reason
|
||||||
|
)
|
||||||
|
return False, err_code, err_reason
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 其他类型暂时视为通过,或者回退到 get_models
|
||||||
|
if hasattr(provider, "get_models"):
|
||||||
|
await asyncio.wait_for(
|
||||||
|
provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT
|
||||||
|
)
|
||||||
|
return True, None, None
|
||||||
|
return True, None, None # 未知类型默认通过
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
err_code = "TIMEOUT"
|
||||||
|
err_reason = "Reachability check timed out"
|
||||||
|
except Exception as exc:
|
||||||
|
err_code = (
|
||||||
|
getattr(exc, "status_code", None)
|
||||||
|
or getattr(exc, "code", None)
|
||||||
|
or getattr(exc, "error_code", None)
|
||||||
|
)
|
||||||
|
err_reason = str(exc)
|
||||||
|
if not err_code:
|
||||||
|
err_code = exc.__class__.__name__
|
||||||
|
|
||||||
|
self._log_reachability_failure(
|
||||||
|
provider, provider_capability_type, err_code, err_reason
|
||||||
|
)
|
||||||
|
return False, err_code, err_reason
|
||||||
|
|
||||||
async def provider(
|
async def provider(
|
||||||
self,
|
self,
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
@@ -17,46 +172,131 @@ class ProviderCommands:
|
|||||||
):
|
):
|
||||||
"""查看或者切换 LLM Provider"""
|
"""查看或者切换 LLM Provider"""
|
||||||
umo = event.unified_msg_origin
|
umo = event.unified_msg_origin
|
||||||
|
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||||
|
reachability_check_enabled = cfg.get("reachability_check", True)
|
||||||
|
|
||||||
if idx is None:
|
if idx is None:
|
||||||
parts = ["## 载入的 LLM 提供商\n"]
|
parts = ["## 载入的 LLM 提供商\n"]
|
||||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
|
||||||
id_ = llm.meta().id
|
# 获取所有类型的提供商
|
||||||
line = f"{idx + 1}. {id_} ({llm.meta().model})"
|
llms = list(self.context.get_all_providers())
|
||||||
|
ttss = self.context.get_all_tts_providers()
|
||||||
|
stts = self.context.get_all_stt_providers()
|
||||||
|
|
||||||
|
# 构造待检测列表: [(provider, type_label), ...]
|
||||||
|
all_providers = []
|
||||||
|
all_providers.extend([(p, "llm") for p in llms])
|
||||||
|
all_providers.extend([(p, "tts") for p in ttss])
|
||||||
|
all_providers.extend([(p, "stt") for p in stts])
|
||||||
|
|
||||||
|
# 并发测试连通性
|
||||||
|
if reachability_check_enabled:
|
||||||
|
if all_providers:
|
||||||
|
await event.send(
|
||||||
|
MessageEventResult().message(
|
||||||
|
"正在进行提供商可达性测试,请稍候..."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
check_results = await asyncio.gather(
|
||||||
|
*[self._test_provider_capability(p) for p, _ in all_providers],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 用 None 表示未检测
|
||||||
|
check_results = [None for _ in all_providers]
|
||||||
|
|
||||||
|
# 整合结果
|
||||||
|
display_data = []
|
||||||
|
for (p, p_type), reachable in zip(all_providers, check_results):
|
||||||
|
meta = p.meta()
|
||||||
|
id_ = meta.id
|
||||||
|
error_code = None
|
||||||
|
|
||||||
|
if isinstance(reachable, Exception):
|
||||||
|
# 异常情况下兜底处理,避免单个 provider 导致列表失败
|
||||||
|
self._log_reachability_failure(
|
||||||
|
p,
|
||||||
|
None,
|
||||||
|
reachable.__class__.__name__,
|
||||||
|
str(reachable),
|
||||||
|
)
|
||||||
|
reachable_flag = False
|
||||||
|
error_code = reachable.__class__.__name__
|
||||||
|
elif isinstance(reachable, tuple):
|
||||||
|
reachable_flag, error_code, _ = reachable
|
||||||
|
else:
|
||||||
|
reachable_flag = reachable
|
||||||
|
|
||||||
|
# 根据类型构建显示名称
|
||||||
|
if p_type == "llm":
|
||||||
|
info = f"{id_} ({meta.model})"
|
||||||
|
else:
|
||||||
|
info = f"{id_}"
|
||||||
|
|
||||||
|
# 确定状态标记
|
||||||
|
if reachable_flag is True:
|
||||||
|
mark = " ✅"
|
||||||
|
elif reachable_flag is False:
|
||||||
|
if error_code:
|
||||||
|
mark = f" ❌(错误码: {error_code})"
|
||||||
|
else:
|
||||||
|
mark = " ❌"
|
||||||
|
else:
|
||||||
|
mark = "" # 不支持检测时不显示标记
|
||||||
|
|
||||||
|
display_data.append(
|
||||||
|
{
|
||||||
|
"type": p_type,
|
||||||
|
"info": info,
|
||||||
|
"mark": mark,
|
||||||
|
"provider": p,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 分组输出
|
||||||
|
# 1. LLM
|
||||||
|
llm_data = [d for d in display_data if d["type"] == "llm"]
|
||||||
|
for i, d in enumerate(llm_data):
|
||||||
|
line = f"{i + 1}. {d['info']}{d['mark']}"
|
||||||
provider_using = self.context.get_using_provider(umo=umo)
|
provider_using = self.context.get_using_provider(umo=umo)
|
||||||
if provider_using and provider_using.meta().id == id_:
|
if (
|
||||||
|
provider_using
|
||||||
|
and provider_using.meta().id == d["provider"].meta().id
|
||||||
|
):
|
||||||
line += " (当前使用)"
|
line += " (当前使用)"
|
||||||
parts.append(line + "\n")
|
parts.append(line + "\n")
|
||||||
|
|
||||||
tts_providers = self.context.get_all_tts_providers()
|
# 2. TTS
|
||||||
if tts_providers:
|
tts_data = [d for d in display_data if d["type"] == "tts"]
|
||||||
|
if tts_data:
|
||||||
parts.append("\n## 载入的 TTS 提供商\n")
|
parts.append("\n## 载入的 TTS 提供商\n")
|
||||||
for idx, tts in enumerate(tts_providers):
|
for i, d in enumerate(tts_data):
|
||||||
id_ = tts.meta().id
|
line = f"{i + 1}. {d['info']}{d['mark']}"
|
||||||
line = f"{idx + 1}. {id_}"
|
|
||||||
tts_using = self.context.get_using_tts_provider(umo=umo)
|
tts_using = self.context.get_using_tts_provider(umo=umo)
|
||||||
if tts_using and tts_using.meta().id == id_:
|
if tts_using and tts_using.meta().id == d["provider"].meta().id:
|
||||||
line += " (当前使用)"
|
line += " (当前使用)"
|
||||||
parts.append(line + "\n")
|
parts.append(line + "\n")
|
||||||
|
|
||||||
stt_providers = self.context.get_all_stt_providers()
|
# 3. STT
|
||||||
if stt_providers:
|
stt_data = [d for d in display_data if d["type"] == "stt"]
|
||||||
|
if stt_data:
|
||||||
parts.append("\n## 载入的 STT 提供商\n")
|
parts.append("\n## 载入的 STT 提供商\n")
|
||||||
for idx, stt in enumerate(stt_providers):
|
for i, d in enumerate(stt_data):
|
||||||
id_ = stt.meta().id
|
line = f"{i + 1}. {d['info']}{d['mark']}"
|
||||||
line = f"{idx + 1}. {id_}"
|
|
||||||
stt_using = self.context.get_using_stt_provider(umo=umo)
|
stt_using = self.context.get_using_stt_provider(umo=umo)
|
||||||
if stt_using and stt_using.meta().id == id_:
|
if stt_using and stt_using.meta().id == d["provider"].meta().id:
|
||||||
line += " (当前使用)"
|
line += " (当前使用)"
|
||||||
parts.append(line + "\n")
|
parts.append(line + "\n")
|
||||||
|
|
||||||
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
|
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
|
||||||
ret = "".join(parts)
|
ret = "".join(parts)
|
||||||
|
|
||||||
if tts_providers:
|
if ttss:
|
||||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||||
if stt_providers:
|
if stts:
|
||||||
ret += "\n使用 /provider stt <切换> STT 提供商。"
|
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
|
||||||
|
if not reachability_check_enabled:
|
||||||
|
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
|
||||||
|
|
||||||
event.set_result(MessageEventResult().message(ret))
|
event.set_result(MessageEventResult().message(ret))
|
||||||
elif idx == "tts":
|
elif idx == "tts":
|
||||||
|
|||||||
Reference in New Issue
Block a user