Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
26467fbe22 feat: implement unified provider availability testing across components
- Added a `test` method to each provider class to standardize availability checks.
- Updated the dashboard and command routes to utilize the new `test` method for provider reachability verification, simplifying the logic and improving maintainability.
- Removed redundant reachability check logic from the command handler.
2025-12-01 00:39:14 +08:00
3 changed files with 88 additions and 283 deletions

View File

@@ -1,5 +1,6 @@
import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from astrbot.core.agent.message import Message
@@ -11,6 +12,7 @@ from astrbot.core.provider.entities import (
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
class AbstractProvider(abc.ABC):
@@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC):
)
return meta
async def test(self) -> bool:
"""test the provider is a
Returns:
bool: the provider is available
"""
return True
class Provider(AbstractProvider):
"""Chat Provider"""
@@ -165,6 +175,16 @@ class Provider(AbstractProvider):
return dicts
async def test(self, timeout: float = 45.0) -> bool:
try:
response = await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout,
)
return response is not None
except Exception:
return False
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -177,6 +197,20 @@ class STTProvider(AbstractProvider):
"""获取音频的文本"""
raise NotImplementedError
async def test(self) -> bool:
try:
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
return False
text_result = await self.get_text(sample_audio_path)
return isinstance(text_result, str) and bool(text_result)
except Exception:
return False
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -189,6 +223,13 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径"""
raise NotImplementedError
async def test(self) -> bool:
try:
audio_result = await self.get_audio("hi")
return isinstance(audio_result, str) and bool(audio_result)
except Exception:
return False
class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -211,6 +252,15 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度"""
...
async def test(self) -> bool:
try:
embedding_result = await self.get_embedding("health_check")
return isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
)
except Exception:
return False
async def get_embeddings_batch(
self,
texts: list[str],
@@ -294,3 +344,10 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]:
"""获取查询和文档的重排序分数"""
...
async def test(self) -> bool:
try:
await self.rerank("Apple", documents=["apple", "banana"])
return True
except Exception:
return False

View File

@@ -18,11 +18,8 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core.utils.astrbot_path import get_astrbot_path
from .route import Response, Route, RouteContext
@@ -356,169 +353,26 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
)
if provider_capability_type == ProviderType.CHAT_COMPLETION:
try:
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=45.0,
)
logger.debug(
f"Received response from {status_info['name']}: {response}",
)
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
)
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {e!s}"
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
try:
# For TTS, we can call the get_audio method with a short prompt.
audio_result = await provider.get_audio("你好")
if isinstance(audio_result, str) and audio_result:
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {e!s}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}",
)
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
)
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
)
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
)
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {e!s}"
elif provider_capability_type == ProviderType.RERANK:
try:
assert isinstance(provider, RerankProvider)
await provider.rerank("Apple", documents=["apple", "banana"])
try:
result = await provider.test()
if result:
status_info["status"] = "available"
except Exception as e:
logger.error(
f"Error testing rerank provider {provider_name}: {e}",
exc_info=True,
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
)
status_info["status"] = "unavailable"
status_info["error"] = f"Rerank test failed: {e!s}"
else:
logger.debug(
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}",
else:
status_info["error"] = "Provider test returned False."
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
)
status_info["status"] = "available"
status_info["error"] = (
"This provider type is not tested and is assumed to be available."
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
return status_info

View File

@@ -1,15 +1,10 @@
import asyncio
import os
import re
from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_path
REACHABILITY_CHECK_TIMEOUT = 30.0
class ProviderCommands:
@@ -34,121 +29,20 @@ class ProviderCommands:
)
async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)"""
"""测试单个 provider 的可用性"""
meta = provider.meta()
provider_capability_type = meta.provider_type
try:
if provider_capability_type == ProviderType.CHAT_COMPLETION:
# 发送 "Ping" 测试对话
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if response is not None:
return True, None, None
err_code = "EMPTY_RESPONSE"
err_reason = "Provider returned empty response"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.EMBEDDING:
# 测试 Embedding
embedding_result = await asyncio.wait_for(
provider.get_embedding("health_check"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if (
isinstance(embedding_result, list)
and embedding_result
and all(isinstance(x, (int, float)) for x in embedding_result)
):
return True, None, None
err_code = "INVALID_EMBEDDING"
err_reason = "Provider returned invalid embedding"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
# 测试 TTS
audio_result = await asyncio.wait_for(
provider.get_audio("你好"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if isinstance(audio_result, str) and audio_result:
# 清理检测生成的临时音频文件,避免频繁检测时堆积
if os.path.isfile(audio_result):
try:
os.remove(audio_result)
except OSError as e:
logger.debug(
"Failed to cleanup TTS health check file %s: %s",
audio_result,
e,
)
return True, None, None
err_code = "INVALID_AUDIO"
err_reason = "Provider returned invalid audio"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
# 测试 STT
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
# 如果样本文件不存在,降级为检查是否实现了方法
return hasattr(provider, "get_text"), None, None
text_result = await asyncio.wait_for(
provider.get_text(sample_audio_path),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if isinstance(text_result, str) and text_result:
return True, None, None
err_code = "INVALID_TEXT"
err_reason = "Provider returned invalid text"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.RERANK:
# 测试 Rerank
if isinstance(provider, RerankProvider):
await asyncio.wait_for(
provider.rerank("Apple", documents=["apple", "banana"]),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
return True, None, None
err_code = "NOT_RERANK_PROVIDER"
err_reason = "Provider is not RerankProvider"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
else:
# 其他类型暂时视为通过,或者回退到 get_models
if hasattr(provider, "get_models"):
await asyncio.wait_for(
provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT
)
return True, None, None
return True, None, None # 未知类型默认通过
except asyncio.TimeoutError:
err_code = "TIMEOUT"
err_reason = "Reachability check timed out"
result = await provider.test()
if result:
return True, None, None
err_code = "TEST_FAILED"
err_reason = "Provider test returned False"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
except Exception as exc:
err_code = (
getattr(exc, "status_code", None)
@@ -159,10 +53,10 @@ class ProviderCommands:
if not err_code:
err_code = exc.__class__.__name__
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def provider(
self,