feat: add Xinference rerank provider (#3162)
* feat: add Xinference rerank provider * feat: add default rerank_api_key option for Xinference provider * style: format code * fix: refactor XinferenceRerankProvider initialization for better error handling * fix: update XinferenceRerankProvider to use async client methods for initialization and reranking * feat: add launch_model_if_not_running option to XinferenceRerankProvider for better control over model initialization * chore: remove unused asyncio import from xinference_rerank_source.py
This commit is contained in:
@@ -1262,6 +1262,18 @@ CONFIG_METADATA_2 = {
|
|||||||
"rerank_model": "BAAI/bge-reranker-base",
|
"rerank_model": "BAAI/bge-reranker-base",
|
||||||
"timeout": 20,
|
"timeout": 20,
|
||||||
},
|
},
|
||||||
|
"Xinference Rerank": {
|
||||||
|
"id": "xinference_rerank",
|
||||||
|
"type": "xinference_rerank",
|
||||||
|
"provider": "xinference",
|
||||||
|
"provider_type": "rerank",
|
||||||
|
"enable": True,
|
||||||
|
"rerank_api_key": "",
|
||||||
|
"rerank_api_base": "http://127.0.0.1:9997",
|
||||||
|
"rerank_model": "BAAI/bge-reranker-base",
|
||||||
|
"timeout": 20,
|
||||||
|
"launch_model_if_not_running": False,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
"rerank_api_base": {
|
"rerank_api_base": {
|
||||||
@@ -1278,6 +1290,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "重排序模型名称",
|
"description": "重排序模型名称",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
},
|
},
|
||||||
|
"launch_model_if_not_running": {
|
||||||
|
"description": "模型未运行时自动启动",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
|
||||||
|
},
|
||||||
"modalities": {
|
"modalities": {
|
||||||
"description": "模型能力",
|
"description": "模型能力",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
|
|||||||
@@ -311,6 +311,10 @@ class ProviderManager:
|
|||||||
from .sources.vllm_rerank_source import (
|
from .sources.vllm_rerank_source import (
|
||||||
VLLMRerankProvider as VLLMRerankProvider,
|
VLLMRerankProvider as VLLMRerankProvider,
|
||||||
)
|
)
|
||||||
|
case "xinference_rerank":
|
||||||
|
from .sources.xinference_rerank_source import (
|
||||||
|
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||||
|
)
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.critical(
|
logger.critical(
|
||||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||||
|
|||||||
108
astrbot/core/provider/sources/xinference_rerank_source.py
Normal file
108
astrbot/core/provider/sources/xinference_rerank_source.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
from xinference_client.client.restful.async_restful_client import (
|
||||||
|
AsyncClient as Client,
|
||||||
|
)
|
||||||
|
from astrbot import logger
|
||||||
|
from ..provider import RerankProvider
|
||||||
|
from ..register import register_provider_adapter
|
||||||
|
from ..entities import ProviderType, RerankResult
|
||||||
|
|
||||||
|
|
||||||
|
@register_provider_adapter(
    "xinference_rerank",
    "Xinference Rerank 适配器",
    provider_type=ProviderType.RERANK,
)
class XinferenceRerankProvider(RerankProvider):
    """Rerank provider backed by a Xinference server's rerank models.

    Connects to a Xinference REST endpoint via the async client, locates (or
    optionally launches) the configured rerank model, and exposes it through
    the ``RerankProvider`` interface.
    """

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Read connection/model settings from the provider config.

        Args:
            provider_config: Per-provider config dict (keys: ``rerank_api_base``,
                ``rerank_api_key``, ``rerank_model``, ``timeout``,
                ``launch_model_if_not_running``).
            provider_settings: Global provider settings, passed to the base class.
        """
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings
        # Default matches the config metadata default and Xinference's
        # standard port (9997); previously this disagreed with the UI default.
        base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:9997")
        self.base_url = base_url.rstrip("/")
        # NOTE(review): timeout is read from config but never applied to the
        # client — confirm whether AsyncClient supports a request timeout.
        self.timeout = provider_config.get("timeout", 20)
        self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
        self.api_key = provider_config.get("rerank_api_key")
        self.launch_model_if_not_running = provider_config.get(
            "launch_model_if_not_running", False
        )
        # Populated by initialize(); rerank() refuses to run until self.model is set.
        self.client = None
        self.model = None
        self.model_uid = None

    async def initialize(self):
        """Create the async client and resolve a handle to the rerank model.

        If the model is not already running on the server, it is launched only
        when ``launch_model_if_not_running`` is enabled; otherwise the provider
        stays unavailable (``self.model`` remains ``None``). All failures are
        logged and swallowed so a bad rerank backend never crashes startup.
        """
        # An empty-string API key (the config default) is treated as "no key".
        if self.api_key:
            logger.info("Xinference Rerank: Using API key for authentication.")
            self.client = Client(self.base_url, api_key=self.api_key)
        else:
            logger.info("Xinference Rerank: No API key provided.")
            self.client = Client(self.base_url)

        try:
            # list_models() maps model UID -> model spec dict.
            running_models = await self.client.list_models()
            for uid, model_spec in running_models.items():
                if model_spec.get("model_name") == self.model_name:
                    logger.info(
                        f"Model '{self.model_name}' is already running with UID: {uid}"
                    )
                    self.model_uid = uid
                    break

            if self.model_uid is None:
                if self.launch_model_if_not_running:
                    logger.info(f"Launching {self.model_name} model...")
                    self.model_uid = await self.client.launch_model(
                        model_name=self.model_name, model_type="rerank"
                    )
                    logger.info("Model launched.")
                else:
                    logger.warning(
                        f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available."
                    )
                    return

            if self.model_uid:
                self.model = await self.client.get_model(self.model_uid)

        except Exception as e:
            logger.error(f"Failed to initialize Xinference model: {e}")
            logger.debug(
                f"Xinference initialization failed with exception: {e}", exc_info=True
            )
            # Leave the provider disabled rather than propagating the error.
            self.model = None

    async def rerank(
        self, query: str, documents: list[str], top_n: int | None = None
    ) -> list[RerankResult]:
        """Rerank *documents* against *query* using the Xinference model.

        Args:
            query: The search query.
            documents: Candidate documents to score.
            top_n: Optional cap on the number of results returned by the server.

        Returns:
            A list of ``RerankResult`` (original index + relevance score), or an
            empty list if the model is unavailable or the call fails.
        """
        if not self.model:
            logger.error("Xinference rerank model is not initialized.")
            return []
        try:
            # Xinference's rerank signature is (documents, query, top_n).
            response = await self.model.rerank(documents, query, top_n)
            # Response schema: {"results": [{"index": int, "relevance_score": float}, ...]}
            results = response.get("results", [])
            logger.debug(f"Rerank API response: {response}")

            if not results:
                logger.warning(
                    f"Rerank API returned an empty list. Original response: {response}"
                )

            return [
                RerankResult(
                    index=result["index"],
                    relevance_score=result["relevance_score"],
                )
                for result in results
            ]
        except Exception as e:
            logger.error(f"Xinference rerank failed: {e}")
            logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True)
            return []

    async def terminate(self) -> None:
        """Close the client session, logging (not raising) any close errors."""
        if self.client:
            logger.info("Closing Xinference rerank client...")
            try:
                await self.client.close()
            except Exception as e:
                logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
@@ -55,6 +55,7 @@ dependencies = [
|
|||||||
"rank-bm25>=0.2.2",
|
"rank-bm25>=0.2.2",
|
||||||
"jieba>=0.42.1",
|
"jieba>=0.42.1",
|
||||||
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
||||||
|
"xinference-client",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -48,4 +48,5 @@ pypdf
|
|||||||
aiofiles
|
aiofiles
|
||||||
rank-bm25
|
rank-bm25
|
||||||
jieba
|
jieba
|
||||||
markitdown-no-magika[docx,xls,xlsx]
|
markitdown-no-magika[docx,xls,xlsx]
|
||||||
|
xinference-client
|
||||||
Reference in New Issue
Block a user