feat: add Xinference rerank provider (#3162)

* feat:add Xinference rerank provider

* feat:add default rerank_api_key option for Xinference provider

* style: format code

* fix: refactor XinferenceRerankProvider initialization for better error handling

* fix: update XinferenceRerankProvider to use async client methods for initialization and reranking

* feat: add launch_model_if_not_running option to XinferenceRerankProvider for better control over model initialization

* chore: remove unused asyncio import from xinference_rerank_source.py
This commit is contained in:
RC-CHN
2025-10-28 18:23:55 +08:00
committed by GitHub
parent 3d88827a95
commit 90a65c35c1
5 changed files with 132 additions and 1 deletions

View File

@@ -1262,6 +1262,18 @@ CONFIG_METADATA_2 = {
"rerank_model": "BAAI/bge-reranker-base",
"timeout": 20,
},
"Xinference Rerank": {
"id": "xinference_rerank",
"type": "xinference_rerank",
"provider": "xinference",
"provider_type": "rerank",
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "http://127.0.0.1:9997",
"rerank_model": "BAAI/bge-reranker-base",
"timeout": 20,
"launch_model_if_not_running": False,
},
},
"items": {
"rerank_api_base": {
@@ -1278,6 +1290,11 @@ CONFIG_METADATA_2 = {
"description": "重排序模型名称",
"type": "string",
},
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
},
"modalities": {
"description": "模型能力",
"type": "list",

View File

@@ -311,6 +311,10 @@ class ProviderManager:
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
case "xinference_rerank":
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"

View File

@@ -0,0 +1,108 @@
from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client,
)
from astrbot import logger
from ..provider import RerankProvider
from ..register import register_provider_adapter
from ..entities import ProviderType, RerankResult
@register_provider_adapter(
"xinference_rerank",
"Xinference Rerank 适配器",
provider_type=ProviderType.RERANK,
)
class XinferenceRerankProvider(RerankProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
self.base_url = self.base_url.rstrip("/")
self.timeout = provider_config.get("timeout", 20)
self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
self.api_key = provider_config.get("rerank_api_key")
self.launch_model_if_not_running = provider_config.get(
"launch_model_if_not_running", False
)
self.client = None
self.model = None
self.model_uid = None
async def initialize(self):
if self.api_key:
logger.info("Xinference Rerank: Using API key for authentication.")
self.client = Client(self.base_url, api_key=self.api_key)
else:
logger.info("Xinference Rerank: No API key provided.")
self.client = Client(self.base_url)
try:
running_models = await self.client.list_models()
for uid, model_spec in running_models.items():
if model_spec.get("model_name") == self.model_name:
logger.info(
f"Model '{self.model_name}' is already running with UID: {uid}"
)
self.model_uid = uid
break
if self.model_uid is None:
if self.launch_model_if_not_running:
logger.info(f"Launching {self.model_name} model...")
self.model_uid = await self.client.launch_model(
model_name=self.model_name, model_type="rerank"
)
logger.info("Model launched.")
else:
logger.warning(
f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
)
return
if self.model_uid:
self.model = await self.client.get_model(self.model_uid)
except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}")
logger.debug(
f"Xinference initialization failed with exception: {e}", exc_info=True
)
self.model = None
async def rerank(
self, query: str, documents: list[str], top_n: int | None = None
) -> list[RerankResult]:
if not self.model:
logger.error("Xinference rerank model is not initialized.")
return []
try:
response = await self.model.rerank(documents, query, top_n)
results = response.get("results", [])
logger.debug(f"Rerank API response: {response}")
if not results:
logger.warning(
f"Rerank API returned an empty list. Original response: {response}"
)
return [
RerankResult(
index=result["index"],
relevance_score=result["relevance_score"],
)
for result in results
]
except Exception as e:
logger.error(f"Xinference rerank failed: {e}")
logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True)
return []
async def terminate(self) -> None:
"""关闭客户端会话"""
if self.client:
logger.info("Closing Xinference rerank client...")
try:
await self.client.close()
except Exception as e:
logger.error(f"Failed to close Xinference client: {e}", exc_info=True)

View File

@@ -55,6 +55,7 @@ dependencies = [
"rank-bm25>=0.2.2",
"jieba>=0.42.1",
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
"xinference-client",
]
[project.scripts]

View File

@@ -49,3 +49,4 @@ aiofiles
rank-bm25
jieba
markitdown-no-magika[docx,xls,xlsx]
xinference-client