refactor(bailian_rerank): 修复误删除并优化top_n参数处理

- 移除不合理的知识库配置读取逻辑
- 添加os模块导入(用于读取环境变量)
- 抽取辅助函数:_build_payload()、_parse_results()、_log_usage()
- 添加自定义异常类:BailianRerankError、BailianAPIError、BailianNetworkError
- 使用.get()安全访问API响应字段,避免KeyError
- 使用raise ... from e保持异常链
This commit is contained in:
piexian
2025-11-21 05:34:18 +08:00
parent 234ce93dc1
commit 2e608cdc09

View File

@@ -1,3 +1,5 @@
import os
import aiohttp
from astrbot import logger
@@ -7,6 +9,24 @@ from ..provider import RerankProvider
from ..register import register_provider_adapter
class BailianRerankError(Exception):
    """Base exception for the Bailian rerank service."""
class BailianAPIError(BailianRerankError):
    """The Bailian API returned an error."""
class BailianNetworkError(BailianRerankError):
    """A network request to Bailian failed."""
@register_provider_adapter(
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
@@ -19,7 +39,9 @@ class BailianRerankProvider(RerankProvider):
self.provider_settings = provider_settings
# API配置
self.api_key = provider_config.get("rerank_api_key", "")
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
"DASHSCOPE_API_KEY", ""
)
if not self.api_key:
raise ValueError("阿里云百炼 API Key 不能为空。")
@@ -48,6 +70,96 @@ class BailianRerankProvider(RerankProvider):
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
def _build_payload(
self, query: str, documents: list[str], top_n: int | None
) -> dict:
"""构建请求载荷
Args:
query: 查询文本
documents: 文档列表
top_n: 返回前N个结果如果为None则返回所有结果
Returns:
请求载荷字典
"""
base = {"model": self.model, "input": {"query": query, "documents": documents}}
params = {
k: v
for k, v in [
("top_n", top_n if top_n is not None and top_n > 0 else None),
("return_documents", True if self.return_documents else None),
(
"instruct",
self.instruct
if self.instruct and self.model == "qwen3-rerank"
else None,
),
]
if v is not None
}
if params:
base["parameters"] = params
return base
def _parse_results(self, data: dict) -> list[RerankResult]:
    """Convert an API response into a list of ``RerankResult`` objects.

    Entries that cannot be parsed are logged and skipped rather than
    failing the whole response.

    Args:
        data: Decoded JSON response body.

    Returns:
        One ``RerankResult`` per entry that could be parsed; an empty
        list when the response carries no results.

    Raises:
        BailianAPIError: The response has a ``code`` other than ``"200"``.
    """
    if data.get("code", "200") != "200":
        raise BailianAPIError(
            f"百炼 API 错误: {data.get('code')} {data.get('message', '')}"
        )

    # ``or`` also covers an explicit JSON null, which ``.get(key, default)``
    # would hand back as None and crash on the following ``.get``.
    results = (data.get("output") or {}).get("results") or []
    if not results:
        logger.warning(f"百炼 Rerank 返回空结果: {data}")
        return []

    rerank_results = []
    for idx, result in enumerate(results):
        try:
            # NOTE(review): falling back to the list position assumes the
            # entries arrive in document order — confirm against the API.
            index = result.get("index", idx)
            # No default here, so an absent key and an explicit null both
            # reach the warning below instead of silently becoming 0.0.
            relevance_score = result.get("relevance_score")
            if relevance_score is None:
                logger.warning(f"结果 {idx} 缺少 relevance_score使用默认值 0.0")
                relevance_score = 0.0
            rerank_results.append(
                RerankResult(index=index, relevance_score=relevance_score)
            )
        except Exception as e:
            # Best effort: one malformed entry must not discard the rest.
            logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
    return rerank_results
def _log_usage(self, data: dict) -> None:
    """Log the token usage reported in an API response, if any.

    Logging is best effort: a missing, null, or malformed ``usage``
    section is ignored instead of raising.

    Args:
        data: Decoded JSON response body.
    """
    usage = data.get("usage")
    # ``usage`` or ``total_tokens`` may be absent or an explicit JSON null;
    # neither case may raise from a purely informational helper.
    tokens = usage.get("total_tokens") if isinstance(usage, dict) else None
    if isinstance(tokens, (int, float)) and tokens > 0:
        logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
async def rerank(
self,
query: str,
@@ -60,7 +172,7 @@ class BailianRerankProvider(RerankProvider):
Args:
query: 查询文本
documents: 待排序的文档列表
top_n: 返回前N个结果如果为None则返回所有重排序结果
top_n: 返回前N个结果如果为None则使用配置中的默认值
Returns:
重排序结果列表
@@ -81,23 +193,8 @@ class BailianRerankProvider(RerankProvider):
documents = documents[:500]
try:
# 构建请求载荷
payload = {
"model": self.model,
"input": {"query": query, "documents": documents},
}
# 添加可选参数
parameters = {}
if top_n is not None and top_n > 0:
parameters["top_n"] = top_n
if self.return_documents:
parameters["return_documents"] = True
if self.instruct and self.model == "qwen3-rerank":
parameters["instruct"] = self.instruct
if parameters:
payload["parameters"] = parameters
# 构建请求载荷如果top_n为None则返回所有重排序结果
payload = self._build_payload(query, documents, top_n)
logger.debug(
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
@@ -108,48 +205,24 @@ class BailianRerankProvider(RerankProvider):
response.raise_for_status()
response_data = await response.json()
# 检查响应状态
if "code" in response_data and response_data["code"] != "200":
error_msg = response_data.get("message", "未知错误")
api_error_msg = (
f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
)
raise RuntimeError(api_error_msg)
# 解析结果并记录使用量
results = self._parse_results(response_data)
self._log_usage(response_data)
# 解析结果
output = response_data.get("output", {})
results = output.get("results", [])
logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
if not results:
logger.warning(f"百炼 Rerank 返回空结果: {response_data}")
return []
# 转换为RerankResult对象
rerank_results = []
for result in results:
rerank_result = RerankResult(
index=result["index"], relevance_score=result["relevance_score"]
)
rerank_results.append(rerank_result)
logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果")
# 记录使用量信息
usage = response_data.get("usage", {})
total_tokens = usage.get("total_tokens", 0)
if total_tokens > 0:
logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}")
return rerank_results
return results
except aiohttp.ClientError as e:
error_msg = f"网络请求失败: {e}"
logger.error(f"百炼 Rerank 网络请求失败: {e}")
raise RuntimeError(error_msg) from e
raise BailianNetworkError(error_msg) from e
except BailianRerankError:
raise
except Exception as e:
error_msg = f"重排序失败: {e}"
logger.error(f"百炼 Rerank 处理失败: {e}")
raise RuntimeError(error_msg) from e
raise BailianRerankError(error_msg) from e
async def terminate(self) -> None:
"""关闭HTTP客户端会话."""