fix:对非文本生成类供应商暂时跳过测试

This commit is contained in:
Ruochen
2025-07-08 16:32:39 +08:00
parent a35f36eeaf
commit c44f085b47

View File

@@ -1,6 +1,7 @@
import typing import typing
import traceback import traceback
from .route import Route, Response, RouteContext from .route import Route, Response, RouteContext
from astrbot.core.provider.entities import ProviderType
from quart import request from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -187,15 +188,12 @@ class ConfigRoute(Route):
"""辅助函数:测试单个 provider 的可用性""" """辅助函数:测试单个 provider 的可用性"""
meta = provider.meta() meta = provider.meta()
provider_name = provider.provider_config.get("id", "Unknown Provider") provider_name = provider.provider_config.get("id", "Unknown Provider")
logger.debug(f"Got provider meta: {meta}") provider_capability_type = meta.provider_type
if not provider_name and meta:
provider_name = meta.id
elif not provider_name:
provider_name = "Unknown Provider"
status_info = { status_info = {
"id": getattr(meta, "id", "Unknown ID"), "id": getattr(meta, "id", "Unknown ID"),
"model": getattr(meta, "model", "Unknown Model"), "model": getattr(meta, "model", "Unknown Model"),
"type": getattr(meta, "type", "Unknown Type"), "type": provider_capability_type.value,
"name": provider_name, "name": provider_name,
"status": "unavailable", # 默认为不可用 "status": "unavailable", # 默认为不可用
"error": None, "error": None,
@@ -203,13 +201,19 @@ class ConfigRoute(Route):
logger.debug( logger.debug(
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})" f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
) )
if provider_capability_type != ProviderType.CHAT_COMPLETION:
logger.debug(f"Provider {provider_name} is not a Chat Completion provider. Marking as available without test. Meta: {meta}")
status_info["status"] = "available"
status_info["error"] = "This provider type is not tested and is assumed to be available."
return status_info
try: try:
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
response = await asyncio.wait_for( response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
) )
logger.debug(f"Received response from {status_info['name']}: {response}") logger.debug(f"Received response from {status_info['name']}: {response}")
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
if response is not None: if response is not None:
status_info["status"] = "available" status_info["status"] = "available"
response_text_snippet = "" response_text_snippet = ""
@@ -232,30 +236,18 @@ class ConfigRoute(Route):
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
) )
else: else:
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None status_info["error"] = "Test call returned None, but expected an LLMResponse object."
status_info["error"] = ( logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.")
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
status_info["error"] = ( status_info["error"] = "Connection timed out after 45 seconds during test call."
"Connection timed out after 45 seconds during test call." logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.")
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
)
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
status_info["error"] = error_message status_info["error"] = error_message
logger.warning( logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}")
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}" logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}")
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
)
return status_info return status_info
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error): def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
@@ -273,14 +265,12 @@ class ConfigRoute(Route):
logger.info(f"API call: /config/provider/check_one id={provider_id}") logger.info(f"API call: /config/provider/check_one id={provider_id}")
try: try:
all_providers = self.core_lifecycle.star_context.get_all_providers() prov_mgr = self.core_lifecycle.provider_manager
# replace manual loop with next(filter(...)) target = prov_mgr.inst_map.get(provider_id)
target = next(
(p for p in all_providers if p.provider_config.get("id") == provider_id),
None
)
if not target: if not target:
return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning) logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.")
return Response().error(f"Provider with id '{provider_id}' not found").__dict__
result = await self._test_single_provider(target) result = await self._test_single_provider(target)
return Response().ok(result).__dict__ return Response().ok(result).__dict__