From 31d53edb9d061bfb956e5164b952c8f21babac7a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 1 Dec 2025 18:37:08 +0800 Subject: [PATCH] refactor: standardize provider test method implementation - Updated the `test` method in all provider classes to remove return values and raise exceptions for failure cases, enhancing clarity and consistency. - Adjusted related logic in the dashboard and command routes to align with the new `test` method behavior, simplifying error handling. --- astrbot/core/provider/provider.py | 70 +++++++++------------------ astrbot/dashboard/routes/config.py | 16 ++---- packages/astrbot/commands/provider.py | 22 ++------- 3 files changed, 33 insertions(+), 75 deletions(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index d8c2b140..2b5057e8 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -45,13 +45,13 @@ class AbstractProvider(abc.ABC): ) return meta - async def test(self) -> bool: + async def test(self): """test the provider is a - Returns: - bool: the provider is available + raises: + Exception: if the provider is not available """ - return True + ... class Provider(AbstractProvider): @@ -175,15 +175,11 @@ class Provider(AbstractProvider): return dicts - async def test(self, timeout: float = 45.0) -> bool: - try: - response = await asyncio.wait_for( - self.text_chat(prompt="REPLY `PONG` ONLY"), - timeout=timeout, - ) - return response is not None - except Exception: - return False + async def test(self, timeout: float = 45.0): + await asyncio.wait_for( + self.text_chat(prompt="REPLY `PONG` ONLY"), + timeout=timeout, + ) class STTProvider(AbstractProvider): @@ -197,19 +193,13 @@ class STTProvider(AbstractProvider): """获取音频的文本""" raise NotImplementedError - async def test(self) -> bool: - try: - sample_audio_path = os.path.join( - get_astrbot_path(), - "samples", - "stt_health_check.wav", - ) - if not os.path.exists(sample_audio_path): - return False - text_result = await self.get_text(sample_audio_path) - return isinstance(text_result, str) and bool(text_result) - except Exception: - return False + async def test(self): + sample_audio_path = os.path.join( + get_astrbot_path(), + "samples", + "stt_health_check.wav", + ) + await self.get_text(sample_audio_path) class TTSProvider(AbstractProvider): @@ -223,12 +213,8 @@ class TTSProvider(AbstractProvider): """获取文本的音频,返回音频文件路径""" raise NotImplementedError - async def test(self) -> bool: - try: - audio_result = await self.get_audio("hi") - return isinstance(audio_result, str) and bool(audio_result) - except Exception: - return False + async def test(self): + await self.get_audio("hi") class EmbeddingProvider(AbstractProvider): @@ -252,14 +238,8 @@ class EmbeddingProvider(AbstractProvider): """获取向量的维度""" ... - async def test(self) -> bool: - try: - embedding_result = await self.get_embedding("health_check") - return isinstance(embedding_result, list) and ( - not embedding_result or isinstance(embedding_result[0], float) - ) - except Exception: - return False + async def test(self): + await self.get_embedding("astrbot") async def get_embeddings_batch( self, @@ -345,9 +325,7 @@ class RerankProvider(AbstractProvider): """获取查询和文档的重排序分数""" ... - async def test(self) -> bool: - try: - await self.rerank("Apple", documents=["apple", "banana"]) - return True - except Exception: - return False + async def test(self): + result = await self.rerank("Apple", documents=["apple", "banana"]) + if not result: + raise Exception("Rerank provider test failed, no results returned") diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 6514032d..1089d8f8 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -354,17 +354,11 @@ class ConfigRoute(Route): ) try: - result = await provider.test() - if result: - status_info["status"] = "available" - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", - ) - else: - status_info["error"] = "Provider test returned False." - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.", - ) + await provider.test() + status_info["status"] = "available" + logger.info( + f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", + ) except Exception as e: error_message = str(e) status_info["error"] = error_message diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py index d75049c0..ce8f3183 100644 --- a/packages/astrbot/commands/provider.py +++ b/packages/astrbot/commands/provider.py @@ -34,25 +34,11 @@ class ProviderCommands: provider_capability_type = meta.provider_type try: - result = await provider.test() - if result: - return True, None, None + await provider.test() + return True, None, None + except Exception as e: err_code = "TEST_FAILED" - err_reason = "Provider test returned False" - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason - except Exception as exc: - err_code = ( - getattr(exc, "status_code", None) - or getattr(exc, "code", None) - or getattr(exc, "error_code", None) - ) - err_reason = str(exc) - if not err_code: - err_code = exc.__class__.__name__ - + err_reason = str(e) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason )