refactor(bailian_rerank): 修复误删除并优化top_n参数处理
- 移除不合理的知识库配置读取逻辑 - 添加os模块导入(用于读取环境变量) - 抽取辅助函数:_build_payload()、_parse_results()、_log_usage() - 添加自定义异常类:BailianRerankError、BailianAPIError、BailianNetworkError - 使用.get()安全访问API响应字段,避免KeyError - 使用raise ... from e保持异常链
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot import logger
|
||||
@@ -7,6 +9,24 @@ from ..provider import RerankProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
class BailianRerankError(Exception):
    """Base class for all errors raised by the Bailian rerank provider."""
|
||||
|
||||
|
||||
class BailianAPIError(BailianRerankError):
    """Raised when the Bailian API response reports an error."""
|
||||
|
||||
|
||||
class BailianNetworkError(BailianRerankError):
    """Raised when the network request to Bailian fails."""
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
|
||||
)
|
||||
@@ -19,7 +39,9 @@ class BailianRerankProvider(RerankProvider):
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
# API配置
|
||||
self.api_key = provider_config.get("rerank_api_key", "")
|
||||
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
|
||||
"DASHSCOPE_API_KEY", ""
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError("阿里云百炼 API Key 不能为空。")
|
||||
|
||||
@@ -48,6 +70,96 @@ class BailianRerankProvider(RerankProvider):
|
||||
|
||||
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
|
||||
|
||||
def _build_payload(
|
||||
self, query: str, documents: list[str], top_n: int | None
|
||||
) -> dict:
|
||||
"""构建请求载荷
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
documents: 文档列表
|
||||
top_n: 返回前N个结果,如果为None则返回所有结果
|
||||
|
||||
Returns:
|
||||
请求载荷字典
|
||||
"""
|
||||
base = {"model": self.model, "input": {"query": query, "documents": documents}}
|
||||
|
||||
params = {
|
||||
k: v
|
||||
for k, v in [
|
||||
("top_n", top_n if top_n is not None and top_n > 0 else None),
|
||||
("return_documents", True if self.return_documents else None),
|
||||
(
|
||||
"instruct",
|
||||
self.instruct
|
||||
if self.instruct and self.model == "qwen3-rerank"
|
||||
else None,
|
||||
),
|
||||
]
|
||||
if v is not None
|
||||
}
|
||||
|
||||
if params:
|
||||
base["parameters"] = params
|
||||
|
||||
return base
|
||||
|
||||
def _parse_results(self, data: dict) -> list[RerankResult]:
    """Convert a rerank API response into RerankResult objects.

    Args:
        data: Decoded JSON body of the API response.

    Returns:
        One RerankResult per usable entry of output.results, in response
        order. Empty when the response holds no results.

    Raises:
        BailianAPIError: The response carries a "code" other than 200.
    """
    code = data.get("code")
    # Compare as text so that an integer 200 is not reported as a failure.
    if code is not None and str(code) != "200":
        raise BailianAPIError(
            f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
        )

    # "output" may be present but null; treat that the same as missing.
    output = data.get("output")
    results = output.get("results") if isinstance(output, dict) else None
    if not results:
        logger.warning(f"百炼 Rerank 返回空结果: {data}")
        return []

    rerank_results = []
    for idx, result in enumerate(results):
        # An entry without an index cannot be mapped back to its document.
        # NOTE(review): the list position is not used as a fallback because
        # results are presumably ordered by relevance rather than by input
        # position - confirm against the API reference.
        if not isinstance(result, dict) or result.get("index") is None:
            logger.warning(f"解析结果 {idx} 时出错: 缺少 index, result={result}")
            continue

        relevance_score = result.get("relevance_score")
        if relevance_score is None:
            logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
            relevance_score = 0.0

        try:
            rerank_results.append(
                RerankResult(index=result["index"], relevance_score=relevance_score)
            )
        except Exception as e:
            # Deliberate best-effort: one entry rejected by RerankResult
            # should not discard the remaining results.
            logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")

    return rerank_results
|
||||
|
||||
def _log_usage(self, data: dict) -> None:
|
||||
"""记录使用量信息
|
||||
|
||||
Args:
|
||||
data: API响应数据
|
||||
"""
|
||||
tokens = data.get("usage", {}).get("total_tokens", 0)
|
||||
if tokens > 0:
|
||||
logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
query: str,
|
||||
@@ -60,7 +172,7 @@ class BailianRerankProvider(RerankProvider):
|
||||
Args:
|
||||
query: 查询文本
|
||||
documents: 待排序的文档列表
|
||||
top_n: 返回前N个结果,如果为None则返回所有重排序结果
|
||||
top_n: Return only the top N results; if None, top_n is omitted from the request and all reranked results are returned
|
||||
|
||||
Returns:
|
||||
重排序结果列表
|
||||
@@ -81,23 +193,8 @@ class BailianRerankProvider(RerankProvider):
|
||||
documents = documents[:500]
|
||||
|
||||
try:
|
||||
# 构建请求载荷
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": {"query": query, "documents": documents},
|
||||
}
|
||||
|
||||
# 添加可选参数
|
||||
parameters = {}
|
||||
if top_n is not None and top_n > 0:
|
||||
parameters["top_n"] = top_n
|
||||
if self.return_documents:
|
||||
parameters["return_documents"] = True
|
||||
if self.instruct and self.model == "qwen3-rerank":
|
||||
parameters["instruct"] = self.instruct
|
||||
|
||||
if parameters:
|
||||
payload["parameters"] = parameters
|
||||
# 构建请求载荷,如果top_n为None则返回所有重排序结果
|
||||
payload = self._build_payload(query, documents, top_n)
|
||||
|
||||
logger.debug(
|
||||
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
|
||||
@@ -108,48 +205,24 @@ class BailianRerankProvider(RerankProvider):
|
||||
response.raise_for_status()
|
||||
response_data = await response.json()
|
||||
|
||||
# 检查响应状态
|
||||
if "code" in response_data and response_data["code"] != "200":
|
||||
error_msg = response_data.get("message", "未知错误")
|
||||
api_error_msg = (
|
||||
f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
|
||||
)
|
||||
raise RuntimeError(api_error_msg)
|
||||
# 解析结果并记录使用量
|
||||
results = self._parse_results(response_data)
|
||||
self._log_usage(response_data)
|
||||
|
||||
# 解析结果
|
||||
output = response_data.get("output", {})
|
||||
results = output.get("results", [])
|
||||
logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
|
||||
|
||||
if not results:
|
||||
logger.warning(f"百炼 Rerank 返回空结果: {response_data}")
|
||||
return []
|
||||
|
||||
# 转换为RerankResult对象
|
||||
rerank_results = []
|
||||
for result in results:
|
||||
rerank_result = RerankResult(
|
||||
index=result["index"], relevance_score=result["relevance_score"]
|
||||
)
|
||||
rerank_results.append(rerank_result)
|
||||
|
||||
logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果")
|
||||
|
||||
# 记录使用量信息
|
||||
usage = response_data.get("usage", {})
|
||||
total_tokens = usage.get("total_tokens", 0)
|
||||
if total_tokens > 0:
|
||||
logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}")
|
||||
|
||||
return rerank_results
|
||||
return results
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
error_msg = f"网络请求失败: {e}"
|
||||
logger.error(f"百炼 Rerank 网络请求失败: {e}")
|
||||
raise RuntimeError(error_msg) from e
|
||||
raise BailianNetworkError(error_msg) from e
|
||||
except BailianRerankError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"重排序失败: {e}"
|
||||
logger.error(f"百炼 Rerank 处理失败: {e}")
|
||||
raise RuntimeError(error_msg) from e
|
||||
raise BailianRerankError(error_msg) from e
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""关闭HTTP客户端会话."""
|
||||
|
||||
Reference in New Issue
Block a user