116 lines
3.1 KiB
Python
116 lines
3.1 KiB
Python
"""知识库上下文注入器
|
|
|
|
负责检索相关知识并格式化为 LLM 可用的上下文文本
|
|
"""
|
|
|
|
from typing import List, Optional
|
|
|
|
from astrbot.core.knowledge_base.database import KBDatabase
|
|
from astrbot.core.knowledge_base.retrieval.manager import (
|
|
RetrievalManager,
|
|
RetrievalResult,
|
|
)
|
|
|
|
|
|
class KnowledgeBaseInjector:
|
|
"""知识库上下文注入器
|
|
|
|
职责:
|
|
- 检索相关知识
|
|
- 格式化为上下文文本
|
|
- 注入到 LLM Prompt
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
kb_db: KBDatabase,
|
|
retrieval_manager: RetrievalManager,
|
|
):
|
|
"""初始化知识库上下文注入器
|
|
|
|
Args:
|
|
kb_db: 知识库数据库实例
|
|
retrieval_manager: 检索管理器实例
|
|
"""
|
|
self.kb_db = kb_db
|
|
self.retrieval_manager = retrieval_manager
|
|
|
|
async def retrieve_and_inject(
|
|
self,
|
|
unified_msg_origin: str,
|
|
query: str,
|
|
top_k: int = 5,
|
|
) -> Optional[dict]:
|
|
"""检索并注入知识库上下文
|
|
|
|
Args:
|
|
unified_msg_origin: 统一消息来源 ID (会话 ID)
|
|
query: 用户查询
|
|
top_k: 返回结果数量
|
|
|
|
Returns:
|
|
Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None
|
|
{
|
|
"context_text": str, # 格式化的上下文文本
|
|
"results": List[dict], # 原始检索结果列表
|
|
}
|
|
"""
|
|
# 1. 获取会话关联的知识库
|
|
kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
|
|
|
|
if not kb_ids:
|
|
return None
|
|
|
|
# 2. 检索知识
|
|
results = await self.retrieval_manager.retrieve(
|
|
query=query,
|
|
kb_ids=kb_ids,
|
|
top_m_final=top_k,
|
|
)
|
|
|
|
if not results:
|
|
return None
|
|
|
|
# 3. 格式化上下文
|
|
context_text = self._format_context(results)
|
|
|
|
# 4. 转换结果为字典格式
|
|
results_dict = [
|
|
{
|
|
"chunk_id": r.chunk_id,
|
|
"doc_id": r.doc_id,
|
|
"kb_id": r.kb_id,
|
|
"kb_name": r.kb_name,
|
|
"doc_name": r.doc_name,
|
|
"chunk_index": r.metadata.get("chunk_index", 0),
|
|
"content": r.content,
|
|
"score": r.score,
|
|
}
|
|
for r in results
|
|
]
|
|
|
|
return {
|
|
"context_text": context_text,
|
|
"results": results_dict,
|
|
}
|
|
|
|
def _format_context(self, results: List[RetrievalResult]) -> str:
|
|
"""格式化知识上下文
|
|
|
|
Args:
|
|
results: 检索结果列表
|
|
|
|
Returns:
|
|
str: 格式化的上下文文本
|
|
"""
|
|
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
|
|
|
for i, result in enumerate(results, 1):
|
|
lines.append(f"【知识 {i}】")
|
|
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
|
lines.append(f"内容: {result.content}")
|
|
lines.append(f"相关度: {result.score:.2f}")
|
|
lines.append("")
|
|
|
|
return "\n".join(lines)
|