refactor: standardize provider test method implementation
- Updated the `test` method in all provider classes to remove return values and raise exceptions for failure cases, enhancing clarity and consistency. - Adjusted related logic in the dashboard and command routes to align with the new `test` method behavior, simplifying error handling.
This commit is contained in:
@@ -45,13 +45,13 @@ class AbstractProvider(abc.ABC):
|
||||
)
|
||||
return meta
|
||||
|
||||
async def test(self) -> bool:
|
||||
async def test(self):
|
||||
"""test the provider is a
|
||||
|
||||
Returns:
|
||||
bool: the provider is available
|
||||
raises:
|
||||
Exception: if the provider is not available
|
||||
"""
|
||||
return True
|
||||
...
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
@@ -175,15 +175,11 @@ class Provider(AbstractProvider):
|
||||
|
||||
return dicts
|
||||
|
||||
async def test(self, timeout: float = 45.0) -> bool:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=timeout,
|
||||
)
|
||||
return response is not None
|
||||
except Exception:
|
||||
return False
|
||||
async def test(self, timeout: float = 45.0):
|
||||
await asyncio.wait_for(
|
||||
self.text_chat(prompt="REPLY `PONG` ONLY"),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
@@ -197,19 +193,13 @@ class STTProvider(AbstractProvider):
|
||||
"""获取音频的文本"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
if not os.path.exists(sample_audio_path):
|
||||
return False
|
||||
text_result = await self.get_text(sample_audio_path)
|
||||
return isinstance(text_result, str) and bool(text_result)
|
||||
except Exception:
|
||||
return False
|
||||
async def test(self):
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
await self.get_text(sample_audio_path)
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
@@ -223,12 +213,8 @@ class TTSProvider(AbstractProvider):
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
audio_result = await self.get_audio("hi")
|
||||
return isinstance(audio_result, str) and bool(audio_result)
|
||||
except Exception:
|
||||
return False
|
||||
async def test(self):
|
||||
await self.get_audio("hi")
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
@@ -252,14 +238,8 @@ class EmbeddingProvider(AbstractProvider):
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
embedding_result = await self.get_embedding("health_check")
|
||||
return isinstance(embedding_result, list) and (
|
||||
not embedding_result or isinstance(embedding_result[0], float)
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
async def test(self):
|
||||
await self.get_embedding("astrbot")
|
||||
|
||||
async def get_embeddings_batch(
|
||||
self,
|
||||
@@ -345,9 +325,7 @@ class RerankProvider(AbstractProvider):
|
||||
"""获取查询和文档的重排序分数"""
|
||||
...
|
||||
|
||||
async def test(self) -> bool:
|
||||
try:
|
||||
await self.rerank("Apple", documents=["apple", "banana"])
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
async def test(self):
|
||||
result = await self.rerank("Apple", documents=["apple", "banana"])
|
||||
if not result:
|
||||
raise Exception("Rerank provider test failed, no results returned")
|
||||
|
||||
@@ -354,17 +354,11 @@ class ConfigRoute(Route):
|
||||
)
|
||||
|
||||
try:
|
||||
result = await provider.test()
|
||||
if result:
|
||||
status_info["status"] = "available"
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
|
||||
)
|
||||
else:
|
||||
status_info["error"] = "Provider test returned False."
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.",
|
||||
)
|
||||
await provider.test()
|
||||
status_info["status"] = "available"
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available.",
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
|
||||
@@ -34,25 +34,11 @@ class ProviderCommands:
|
||||
provider_capability_type = meta.provider_type
|
||||
|
||||
try:
|
||||
result = await provider.test()
|
||||
if result:
|
||||
return True, None, None
|
||||
await provider.test()
|
||||
return True, None, None
|
||||
except Exception as e:
|
||||
err_code = "TEST_FAILED"
|
||||
err_reason = "Provider test returned False"
|
||||
self._log_reachability_failure(
|
||||
provider, provider_capability_type, err_code, err_reason
|
||||
)
|
||||
return False, err_code, err_reason
|
||||
except Exception as exc:
|
||||
err_code = (
|
||||
getattr(exc, "status_code", None)
|
||||
or getattr(exc, "code", None)
|
||||
or getattr(exc, "error_code", None)
|
||||
)
|
||||
err_reason = str(exc)
|
||||
if not err_code:
|
||||
err_code = exc.__class__.__name__
|
||||
|
||||
err_reason = str(e)
|
||||
self._log_reachability_failure(
|
||||
provider, provider_capability_type, err_code, err_reason
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user