Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
26467fbe22 feat: implement unified provider availability testing across components
- Added a `test` method to each provider class to standardize availability checks.
- Updated the dashboard and command routes to utilize the new `test` method for provider reachability verification, simplifying the logic and improving maintainability.
- Removed redundant reachability check logic from the command handler.
2025-12-01 00:39:14 +08:00
3 changed files with 88 additions and 283 deletions

View File

@@ -1,5 +1,6 @@
import abc import abc
import asyncio import asyncio
import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
@@ -11,6 +12,7 @@ from astrbot.core.provider.entities import (
ToolCallsResult, ToolCallsResult,
) )
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
@@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC):
) )
return meta return meta
async def test(self) -> bool:
"""test the provider is a
Returns:
bool: the provider is available
"""
return True
class Provider(AbstractProvider): class Provider(AbstractProvider):
"""Chat Provider""" """Chat Provider"""
@@ -165,6 +175,16 @@ class Provider(AbstractProvider):
return dicts return dicts
async def test(self, timeout: float = 45.0) -> bool:
try:
response = await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout,
)
return response is not None
except Exception:
return False
class STTProvider(AbstractProvider): class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -177,6 +197,20 @@ class STTProvider(AbstractProvider):
"""获取音频的文本""" """获取音频的文本"""
raise NotImplementedError raise NotImplementedError
async def test(self) -> bool:
try:
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
return False
text_result = await self.get_text(sample_audio_path)
return isinstance(text_result, str) and bool(text_result)
except Exception:
return False
class TTSProvider(AbstractProvider): class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -189,6 +223,13 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径""" """获取文本的音频,返回音频文件路径"""
raise NotImplementedError raise NotImplementedError
async def test(self) -> bool:
try:
audio_result = await self.get_audio("hi")
return isinstance(audio_result, str) and bool(audio_result)
except Exception:
return False
class EmbeddingProvider(AbstractProvider): class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None: def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -211,6 +252,15 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度""" """获取向量的维度"""
... ...
async def test(self) -> bool:
try:
embedding_result = await self.get_embedding("health_check")
return isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
)
except Exception:
return False
async def get_embeddings_batch( async def get_embeddings_batch(
self, self,
texts: list[str], texts: list[str],
@@ -294,3 +344,10 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]: ) -> list[RerankResult]:
"""获取查询和文档的重排序分数""" """获取查询和文档的重排序分数"""
... ...
async def test(self) -> bool:
try:
await self.rerank("Apple", documents=["apple", "banana"])
return True
except Exception:
return False

View File

@@ -18,11 +18,8 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_registry
from astrbot.core.utils.astrbot_path import get_astrbot_path
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -356,169 +353,26 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})",
) )
if provider_capability_type == ProviderType.CHAT_COMPLETION: try:
try: result = await provider.test()
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") if result:
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=45.0,
)
logger.debug(
f"Received response from {status_info['name']}: {response}",
)
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if (
hasattr(response, "completion_text")
and response.completion_text
):
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'",
)
else:
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.",
)
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.",
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
)
elif provider_capability_type == ProviderType.EMBEDDING:
try:
# For embedding, we can call the get_embedding method with a short prompt.
embedding_result = await provider.get_embedding("health_check")
if isinstance(embedding_result, list) and (
not embedding_result or isinstance(embedding_result[0], float)
):
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"Embedding test failed: unexpected result type {type(embedding_result)}"
)
except Exception as e:
logger.error(
f"Error testing embedding provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Embedding test failed: {e!s}"
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
try:
# For TTS, we can call the get_audio method with a short prompt.
audio_result = await provider.get_audio("你好")
if isinstance(audio_result, str) and audio_result:
status_info["status"] = "available"
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"TTS test failed: unexpected result type {type(audio_result)}"
)
except Exception as e:
logger.error(
f"Error testing TTS provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"TTS test failed: {e!s}"
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
try:
logger.debug(
f"Sending health check audio to provider: {status_info['name']}",
)
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (
"STT test failed: sample audio file not found."
)
logger.warning(
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}",
)
else:
text_result = await provider.get_text(sample_audio_path)
if isinstance(text_result, str) and text_result:
status_info["status"] = "available"
snippet = (
text_result[:70] + "..."
if len(text_result) > 70
else text_result
)
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'",
)
else:
status_info["status"] = "unavailable"
status_info["error"] = (
f"STT test failed: unexpected result type {type(text_result)}"
)
logger.warning(
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}",
)
except Exception as e:
logger.error(
f"Error testing STT provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {e!s}"
elif provider_capability_type == ProviderType.RERANK:
try:
assert isinstance(provider, RerankProvider)
await provider.rerank("Apple", documents=["apple", "banana"])
status_info["status"] = "available" status_info["status"] = "available"
except Exception as e: logger.info(
logger.error( f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
f"Error testing rerank provider {provider_name}: {e}",
exc_info=True,
) )
status_info["status"] = "unavailable" else:
status_info["error"] = f"Rerank test failed: {e!s}" status_info["error"] = "Provider test returned False."
logger.warning(
else: f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
logger.debug( )
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}", except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}",
) )
status_info["status"] = "available" logger.debug(
status_info["error"] = ( f"Traceback for {status_info['name']}:\n{traceback.format_exc()}",
"This provider type is not tested and is assumed to be available."
) )
return status_info return status_info

View File

@@ -1,15 +1,10 @@
import asyncio import asyncio
import os
import re import re
from astrbot import logger from astrbot import logger
from astrbot.api import star from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_path
REACHABILITY_CHECK_TIMEOUT = 30.0
class ProviderCommands: class ProviderCommands:
@@ -34,121 +29,20 @@ class ProviderCommands:
) )
async def _test_provider_capability(self, provider): async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)""" """测试单个 provider 的可用性"""
meta = provider.meta() meta = provider.meta()
provider_capability_type = meta.provider_type provider_capability_type = meta.provider_type
try: try:
if provider_capability_type == ProviderType.CHAT_COMPLETION: result = await provider.test()
# 发送 "Ping" 测试对话 if result:
response = await asyncio.wait_for( return True, None, None
provider.text_chat(prompt="REPLY `PONG` ONLY"), err_code = "TEST_FAILED"
timeout=REACHABILITY_CHECK_TIMEOUT, err_reason = "Provider test returned False"
) self._log_reachability_failure(
if response is not None: provider, provider_capability_type, err_code, err_reason
return True, None, None )
err_code = "EMPTY_RESPONSE" return False, err_code, err_reason
err_reason = "Provider returned empty response"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.EMBEDDING:
# 测试 Embedding
embedding_result = await asyncio.wait_for(
provider.get_embedding("health_check"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if (
isinstance(embedding_result, list)
and embedding_result
and all(isinstance(x, (int, float)) for x in embedding_result)
):
return True, None, None
err_code = "INVALID_EMBEDDING"
err_reason = "Provider returned invalid embedding"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
# 测试 TTS
audio_result = await asyncio.wait_for(
provider.get_audio("你好"),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if isinstance(audio_result, str) and audio_result:
# 清理检测生成的临时音频文件,避免频繁检测时堆积
if os.path.isfile(audio_result):
try:
os.remove(audio_result)
except OSError as e:
logger.debug(
"Failed to cleanup TTS health check file %s: %s",
audio_result,
e,
)
return True, None, None
err_code = "INVALID_AUDIO"
err_reason = "Provider returned invalid audio"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
# 测试 STT
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
if not os.path.exists(sample_audio_path):
# 如果样本文件不存在,降级为检查是否实现了方法
return hasattr(provider, "get_text"), None, None
text_result = await asyncio.wait_for(
provider.get_text(sample_audio_path),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
if isinstance(text_result, str) and text_result:
return True, None, None
err_code = "INVALID_TEXT"
err_reason = "Provider returned invalid text"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
elif provider_capability_type == ProviderType.RERANK:
# 测试 Rerank
if isinstance(provider, RerankProvider):
await asyncio.wait_for(
provider.rerank("Apple", documents=["apple", "banana"]),
timeout=REACHABILITY_CHECK_TIMEOUT,
)
return True, None, None
err_code = "NOT_RERANK_PROVIDER"
err_reason = "Provider is not RerankProvider"
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
else:
# 其他类型暂时视为通过,或者回退到 get_models
if hasattr(provider, "get_models"):
await asyncio.wait_for(
provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT
)
return True, None, None
return True, None, None # 未知类型默认通过
except asyncio.TimeoutError:
err_code = "TIMEOUT"
err_reason = "Reachability check timed out"
except Exception as exc: except Exception as exc:
err_code = ( err_code = (
getattr(exc, "status_code", None) getattr(exc, "status_code", None)
@@ -159,10 +53,10 @@ class ProviderCommands:
if not err_code: if not err_code:
err_code = exc.__class__.__name__ err_code = exc.__class__.__name__
self._log_reachability_failure( self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason provider, provider_capability_type, err_code, err_reason
) )
return False, err_code, err_reason return False, err_code, err_reason
async def provider( async def provider(
self, self,