diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 98e8fab8..36401b08 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,8 @@ import abc from typing import List from typing import TypedDict, AsyncGenerator from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entities import LLMResponse, ToolCallsResult +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType +from astrbot.core.provider.register import provider_cls_map from dataclasses import dataclass @@ -22,6 +23,7 @@ class ProviderMeta: id: str model: str type: str + provider_type: ProviderType class AbstractProvider(abc.ABC): @@ -40,10 +42,14 @@ class AbstractProvider(abc.ABC): def meta(self) -> ProviderMeta: """获取 Provider 的元数据""" + provider_type_name = self.provider_config["type"] + meta_data = provider_cls_map.get(provider_type_name) + provider_type = meta_data.provider_type if meta_data else None return ProviderMeta( id=self.provider_config["id"], model=self.get_model(), - type=self.provider_config["type"], + type=provider_type_name, + provider_type=provider_type, ) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 90c61ca6..7de720a3 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,9 +1,11 @@ -import numbers import typing import traceback +import os from .route import Route, Response, RouteContext +from astrbot.core.provider.entities import ProviderType from quart import request from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP +from astrbot.core.utils.astrbot_path import get_astrbot_path from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.platform.register import platform_registry @@ -188,32 +190,27 @@ class ConfigRoute(Route): """辅助函数:测试单个 provider 的可用性""" meta = provider.meta() provider_name = provider.provider_config.get("id", "Unknown Provider") - provider_type = provider.provider_config.get("provider_type", "Unknown Type") - logger.debug(f"Got provider meta: {meta}") - if not provider_name and meta: - provider_name = meta.id - elif not provider_name: - provider_name = "Unknown Provider" + provider_capability_type = meta.provider_type + status_info = { "id": getattr(meta, "id", "Unknown ID"), "model": getattr(meta, "model", "Unknown Model"), - "type": getattr(meta, "type", "Unknown Type"), + "type": provider_capability_type.value, "name": provider_name, - "provider_type": provider_type, "status": "unavailable", # 默认为不可用 "error": None, } logger.debug( f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})" ) - try: - if status_info["provider_type"] == "chat_completion": + + if provider_capability_type == ProviderType.CHAT_COMPLETION: + try: logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") response = await asyncio.wait_for( provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 ) logger.debug(f"Received response from {status_info['name']}: {response}") - # 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用 if response is not None: status_info["status"] = "available" response_text_snippet = "" @@ -236,49 +233,72 @@ class ConfigRoute(Route): f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" ) else: - # 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None - status_info["error"] = ( - "Test call returned None, but expected an LLMResponse object." - ) - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None." - ) - elif status_info["provider_type"] == "embedding": - logger.debug(f"Sending 'astrbot' to embedding provider: {status_info['name']}") - response = await asyncio.wait_for( - provider.get_embedding("astrbot"), timeout=45.0 - ) - logger.debug(f"Received response from {status_info['name']}: {response}") - # 若返回向量则认为该嵌入模型可用 - if response and isinstance(response, typing.Iterable) and all(isinstance(x, numbers.Number) for x in response): - status_info["status"] = "available" - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{str(response)[:10]}...'" - ) - else: - status_info["error"] = ( - f"Status checking for provider type '{status_info['type']}' not implemented." - ) - logger.warning( - f"Provider {status_info['name']}'s status checking not implemented yet" - ) + status_info["error"] = "Test call returned None, but expected an LLMResponse object." + logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.") + + except asyncio.TimeoutError: + status_info["error"] = "Connection timed out after 45 seconds during test call." + logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.") + except Exception as e: + error_message = str(e) + status_info["error"] = error_message + logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}") + logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}") + + elif provider_capability_type == ProviderType.EMBEDDING: + try: + # For embedding, we can call the get_embedding method with a short prompt. + embedding_result = await provider.get_embedding("health_check") + if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)): + status_info["status"] = "available" + else: + status_info["status"] = "unavailable" + status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}" + except Exception as e: + logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True) + status_info["status"] = "unavailable" + status_info["error"] = f"Embedding test failed: {str(e)}" + + elif provider_capability_type == ProviderType.TEXT_TO_SPEECH: + try: + # For TTS, we can call the get_audio method with a short prompt. + audio_result = await provider.get_audio("你好") + if isinstance(audio_result, str) and audio_result: + status_info["status"] = "available" + else: + status_info["status"] = "unavailable" + status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}" + except Exception as e: + logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True) + status_info["status"] = "unavailable" + status_info["error"] = f"TTS test failed: {str(e)}" + elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: + try: + logger.debug(f"Sending health check audio to provider: {status_info['name']}") + sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav") + if not os.path.exists(sample_audio_path): + status_info["status"] = "unavailable" + status_info["error"] = "STT test failed: sample audio file not found." + logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}") + else: + text_result = await provider.get_text(sample_audio_path) + if isinstance(text_result, str) and text_result: + status_info["status"] = "available" + snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result + logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'") + else: + status_info["status"] = "unavailable" + status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}" + logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}") + except Exception as e: + logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True) + status_info["status"] = "unavailable" + status_info["error"] = f"STT test failed: {str(e)}" + else: + logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}") + status_info["status"] = "available" + status_info["error"] = "This provider type is not tested and is assumed to be available." - except asyncio.TimeoutError: - status_info["error"] = ( - "Connection timed out after 45 seconds during test call." - ) - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) timed out." - ) - except Exception as e: - error_message = str(e) - status_info["error"] = error_message - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}" - ) - logger.debug( - f"Traceback for {status_info['name']}:\n{traceback.format_exc()}" - ) return status_info def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error): @@ -296,12 +316,12 @@ class ConfigRoute(Route): logger.info(f"API call: /config/provider/check_one id={provider_id}") try: - all_providers = self.core_lifecycle.star_context.get_all_providers() - all_providers += self.core_lifecycle.star_context.get_all_embedding_providers() - # replace manual loop with next(filter(...)) - target = next(filter(lambda p: p.provider_config.get("id") == provider_id, all_providers), None) + prov_mgr = self.core_lifecycle.provider_manager + target = prov_mgr.inst_map.get(provider_id) + if not target: - return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning) + logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.") + return Response().error(f"Provider with id '{provider_id}' not found").__dict__ result = await self._test_single_provider(target) return Response().ok(result).__dict__ diff --git a/samples/stt_health_check.wav b/samples/stt_health_check.wav new file mode 100644 index 00000000..2c6182c7 Binary files /dev/null and b/samples/stt_health_check.wav differ