refactor(bailian_rerank): 优化代码质量和错误处理

- 移除未使用的 os 导入
- 简化 API Key 验证逻辑
- 优化 top_n 参数处理,优先使用传入值
- 改进错误处理,使用 RuntimeError 替代通用 Exception
- 添加异常链保持原始错误上下文
This commit is contained in:
piexian
2025-11-21 04:07:45 +08:00
parent 2ada1deb9a
commit 234ce93dc1

View File

@@ -1,5 +1,3 @@
import os
import aiohttp
from astrbot import logger
@@ -13,7 +11,7 @@ from ..register import register_provider_adapter
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
"""阿里云百炼文本重排序适配器"""
"""阿里云百炼文本重排序适配器."""
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
@@ -23,19 +21,10 @@ class BailianRerankProvider(RerankProvider):
# API配置
self.api_key = provider_config.get("rerank_api_key", "")
if not self.api_key:
self.api_key = os.getenv("DASHSCOPE_API_KEY", "")
if not self.api_key:
raise ValueError(
"阿里云百炼 API Key 不能为空,请在配置中设置 rerank_api_key 或设置环境变量 DASHSCOPE_API_KEY"
)
raise ValueError("阿里云百炼 API Key 不能为空。")
self.model = provider_config.get("rerank_model", "qwen3-rerank")
self.timeout = provider_config.get("timeout", 30)
# 自动读取知识库配置的 kb_final_top_k如果没有则使用配置中的 top_n
self.default_top_n = provider_settings.get(
"kb_final_top_k"
) or provider_config.get("top_n", 5)
self.return_documents = provider_config.get("return_documents", False)
self.instruct = provider_config.get("instruct", "")
@@ -71,7 +60,7 @@ class BailianRerankProvider(RerankProvider):
Args:
query: 查询文本
documents: 待排序的文档列表
top_n: 返回前N个结果如果为None则使用配置中的默认值
top_n: 返回前N个结果如果为None则返回所有重排序结果
Returns:
重排序结果列表
@@ -91,9 +80,6 @@ class BailianRerankProvider(RerankProvider):
)
documents = documents[:500]
# 优先使用传入的top_n参数来自知识库配置如果没有才使用默认配置
final_top_n = top_n if top_n is not None else self.default_top_n
try:
# 构建请求载荷
payload = {
@@ -103,8 +89,8 @@ class BailianRerankProvider(RerankProvider):
# 添加可选参数
parameters = {}
if final_top_n is not None:
parameters["top_n"] = final_top_n
if top_n is not None and top_n > 0:
parameters["top_n"] = top_n
if self.return_documents:
parameters["return_documents"] = True
if self.instruct and self.model == "qwen3-rerank":
@@ -125,9 +111,10 @@ class BailianRerankProvider(RerankProvider):
# 检查响应状态
if "code" in response_data and response_data["code"] != "200":
error_msg = response_data.get("message", "未知错误")
raise Exception(
api_error_msg = (
f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
)
raise RuntimeError(api_error_msg)
# 解析结果
output = response_data.get("output", {})
@@ -156,14 +143,16 @@ class BailianRerankProvider(RerankProvider):
return rerank_results
except aiohttp.ClientError as e:
error_msg = f"网络请求失败: {e}"
logger.error(f"百炼 Rerank 网络请求失败: {e}")
raise Exception(f"网络请求失败: {e}")
raise RuntimeError(error_msg) from e
except Exception as e:
error_msg = f"重排序失败: {e}"
logger.error(f"百炼 Rerank 处理失败: {e}")
raise Exception(f"重排序失败: {e}")
raise RuntimeError(error_msg) from e
async def terminate(self) -> None:
"""关闭HTTP客户端会话"""
"""关闭HTTP客户端会话."""
if self.client:
logger.info("关闭 百炼 Rerank 客户端会话")
try: