refactor(bailian_rerank): 优化代码质量和错误处理
- 移除未使用的 os 导入 - 简化 API Key 验证逻辑 - 优化 top_n 参数处理,优先使用传入值 - 改进错误处理,使用 RuntimeError 替代通用 Exception - 添加异常链保持原始错误上下文
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot import logger
|
||||
@@ -13,7 +11,7 @@ from ..register import register_provider_adapter
|
||||
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
|
||||
)
|
||||
class BailianRerankProvider(RerankProvider):
|
||||
"""阿里云百炼文本重排序适配器"""
|
||||
"""阿里云百炼文本重排序适配器."""
|
||||
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
@@ -23,19 +21,10 @@ class BailianRerankProvider(RerankProvider):
|
||||
# API配置
|
||||
self.api_key = provider_config.get("rerank_api_key", "")
|
||||
if not self.api_key:
|
||||
self.api_key = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"阿里云百炼 API Key 不能为空,请在配置中设置 rerank_api_key 或设置环境变量 DASHSCOPE_API_KEY"
|
||||
)
|
||||
raise ValueError("阿里云百炼 API Key 不能为空。")
|
||||
|
||||
self.model = provider_config.get("rerank_model", "qwen3-rerank")
|
||||
self.timeout = provider_config.get("timeout", 30)
|
||||
# 自动读取知识库配置的 kb_final_top_k,如果没有则使用配置中的 top_n
|
||||
self.default_top_n = provider_settings.get(
|
||||
"kb_final_top_k"
|
||||
) or provider_config.get("top_n", 5)
|
||||
self.return_documents = provider_config.get("return_documents", False)
|
||||
self.instruct = provider_config.get("instruct", "")
|
||||
|
||||
@@ -71,7 +60,7 @@ class BailianRerankProvider(RerankProvider):
|
||||
Args:
|
||||
query: 查询文本
|
||||
documents: 待排序的文档列表
|
||||
top_n: 返回前N个结果,如果为None则使用配置中的默认值
|
||||
top_n: 返回前N个结果,如果为None则返回所有重排序结果
|
||||
|
||||
Returns:
|
||||
重排序结果列表
|
||||
@@ -91,9 +80,6 @@ class BailianRerankProvider(RerankProvider):
|
||||
)
|
||||
documents = documents[:500]
|
||||
|
||||
# 优先使用传入的top_n参数(来自知识库配置),如果没有才使用默认配置
|
||||
final_top_n = top_n if top_n is not None else self.default_top_n
|
||||
|
||||
try:
|
||||
# 构建请求载荷
|
||||
payload = {
|
||||
@@ -103,8 +89,8 @@ class BailianRerankProvider(RerankProvider):
|
||||
|
||||
# 添加可选参数
|
||||
parameters = {}
|
||||
if final_top_n is not None:
|
||||
parameters["top_n"] = final_top_n
|
||||
if top_n is not None and top_n > 0:
|
||||
parameters["top_n"] = top_n
|
||||
if self.return_documents:
|
||||
parameters["return_documents"] = True
|
||||
if self.instruct and self.model == "qwen3-rerank":
|
||||
@@ -125,9 +111,10 @@ class BailianRerankProvider(RerankProvider):
|
||||
# 检查响应状态
|
||||
if "code" in response_data and response_data["code"] != "200":
|
||||
error_msg = response_data.get("message", "未知错误")
|
||||
raise Exception(
|
||||
api_error_msg = (
|
||||
f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
|
||||
)
|
||||
raise RuntimeError(api_error_msg)
|
||||
|
||||
# 解析结果
|
||||
output = response_data.get("output", {})
|
||||
@@ -156,14 +143,16 @@ class BailianRerankProvider(RerankProvider):
|
||||
return rerank_results
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
error_msg = f"网络请求失败: {e}"
|
||||
logger.error(f"百炼 Rerank 网络请求失败: {e}")
|
||||
raise Exception(f"网络请求失败: {e}")
|
||||
raise RuntimeError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"重排序失败: {e}"
|
||||
logger.error(f"百炼 Rerank 处理失败: {e}")
|
||||
raise Exception(f"重排序失败: {e}")
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""关闭HTTP客户端会话"""
|
||||
"""关闭HTTP客户端会话."""
|
||||
if self.client:
|
||||
logger.info("关闭 百炼 Rerank 客户端会话")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user