fix:对非文本生成类供应商暂时跳过测试
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import typing
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -187,15 +188,12 @@ class ConfigRoute(Route):
|
||||
"""辅助函数:测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
provider_name = provider.provider_config.get("id", "Unknown Provider")
|
||||
logger.debug(f"Got provider meta: {meta}")
|
||||
if not provider_name and meta:
|
||||
provider_name = meta.id
|
||||
elif not provider_name:
|
||||
provider_name = "Unknown Provider"
|
||||
provider_capability_type = meta.provider_type
|
||||
|
||||
status_info = {
|
||||
"id": getattr(meta, "id", "Unknown ID"),
|
||||
"model": getattr(meta, "model", "Unknown Model"),
|
||||
"type": getattr(meta, "type", "Unknown Type"),
|
||||
"type": provider_capability_type.value,
|
||||
"name": provider_name,
|
||||
"status": "unavailable", # 默认为不可用
|
||||
"error": None,
|
||||
@@ -203,13 +201,19 @@ class ConfigRoute(Route):
|
||||
logger.debug(
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
|
||||
)
|
||||
|
||||
if provider_capability_type != ProviderType.CHAT_COMPLETION:
|
||||
logger.debug(f"Provider {provider_name} is not a Chat Completion provider. Marking as available without test. Meta: {meta}")
|
||||
status_info["status"] = "available"
|
||||
status_info["error"] = "This provider type is not tested and is assumed to be available."
|
||||
return status_info
|
||||
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
@@ -232,30 +236,18 @@ class ConfigRoute(Route):
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
)
|
||||
else:
|
||||
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
)
|
||||
status_info["error"] = "Test call returned None, but expected an LLMResponse object."
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
|
||||
)
|
||||
status_info["error"] = "Connection timed out after 45 seconds during test call."
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.")
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}")
|
||||
logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}")
|
||||
|
||||
return status_info
|
||||
|
||||
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
|
||||
@@ -273,14 +265,12 @@ class ConfigRoute(Route):
|
||||
|
||||
logger.info(f"API call: /config/provider/check_one id={provider_id}")
|
||||
try:
|
||||
all_providers = self.core_lifecycle.star_context.get_all_providers()
|
||||
# replace manual loop with next(filter(...))
|
||||
target = next(
|
||||
(p for p in all_providers if p.provider_config.get("id") == provider_id),
|
||||
None
|
||||
)
|
||||
prov_mgr = self.core_lifecycle.provider_manager
|
||||
target = prov_mgr.inst_map.get(provider_id)
|
||||
|
||||
if not target:
|
||||
return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning)
|
||||
logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.")
|
||||
return Response().error(f"Provider with id '{provider_id}' not found").__dict__
|
||||
|
||||
result = await self._test_single_provider(target)
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
Reference in New Issue
Block a user