* feat: 添加从 URL 上传文档的功能,支持进度回调和错误处理 * feat: 添加从 URL 上传文档的前端 * chore: 添加 URL 上传功能的警告提示,确保用户配置正确 * feat: 添加内容清洗功能,支持从 URL 上传文档时的清洗设置和服务提供商选择 * feat: 更新内容清洗系统提示,增强信息提取规则;添加 URL 上传功能的测试版标识 * style: format code * perf: 优化上传设置,增强 URL 上传时的禁用逻辑和清洗提供商验证 * refactor:使用自带chunking模块 * refactor: 提取prompt到单独文件 * feat: 添加 Tavily API Key 配置对话框,增强网页搜索功能的配置体验 * fix: update URL hint and warning messages for clarity in knowledge base upload settings * fix: 修复设置tavily_key的热重载问题 --------- Co-authored-by: Soulter <905617992@qq.com>
331 lines
11 KiB
Python
331 lines
11 KiB
Python
import traceback
|
|
from pathlib import Path
|
|
|
|
from astrbot.core import logger
|
|
from astrbot.core.provider.manager import ProviderManager
|
|
|
|
# from .chunking.fixed_size import FixedSizeChunker
|
|
from .chunking.recursive import RecursiveCharacterChunker
|
|
from .kb_db_sqlite import KBSQLiteDatabase
|
|
from .kb_helper import KBHelper
|
|
from .models import KBDocument, KnowledgeBase
|
|
from .retrieval.manager import RetrievalManager, RetrievalResult
|
|
from .retrieval.rank_fusion import RankFusion
|
|
from .retrieval.sparse_retriever import SparseRetriever
|
|
|
|
FILES_PATH = "data/knowledge_base"
|
|
DB_PATH = Path(FILES_PATH) / "kb.db"
|
|
"""Knowledge Base storage root directory"""
|
|
CHUNKER = RecursiveCharacterChunker()
|
|
|
|
|
|
class KnowledgeBaseManager:
|
|
kb_db: KBSQLiteDatabase
|
|
retrieval_manager: RetrievalManager
|
|
|
|
def __init__(
|
|
self,
|
|
provider_manager: ProviderManager,
|
|
):
|
|
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
|
|
self.provider_manager = provider_manager
|
|
self._session_deleted_callback_registered = False
|
|
|
|
self.kb_insts: dict[str, KBHelper] = {}
|
|
|
|
async def initialize(self):
|
|
"""初始化知识库模块"""
|
|
try:
|
|
logger.info("正在初始化知识库模块...")
|
|
|
|
# 初始化数据库
|
|
await self._init_kb_database()
|
|
|
|
# 初始化检索管理器
|
|
sparse_retriever = SparseRetriever(self.kb_db)
|
|
rank_fusion = RankFusion(self.kb_db)
|
|
self.retrieval_manager = RetrievalManager(
|
|
sparse_retriever=sparse_retriever,
|
|
rank_fusion=rank_fusion,
|
|
kb_db=self.kb_db,
|
|
)
|
|
await self.load_kbs()
|
|
|
|
except ImportError as e:
|
|
logger.error(f"知识库模块导入失败: {e}")
|
|
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
|
except Exception as e:
|
|
logger.error(f"知识库模块初始化失败: {e}")
|
|
logger.error(traceback.format_exc())
|
|
|
|
async def _init_kb_database(self):
|
|
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
|
|
await self.kb_db.initialize()
|
|
await self.kb_db.migrate_to_v1()
|
|
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
|
|
|
|
async def load_kbs(self):
|
|
"""加载所有知识库实例"""
|
|
kb_records = await self.kb_db.list_kbs()
|
|
for record in kb_records:
|
|
kb_helper = KBHelper(
|
|
kb_db=self.kb_db,
|
|
kb=record,
|
|
provider_manager=self.provider_manager,
|
|
kb_root_dir=FILES_PATH,
|
|
chunker=CHUNKER,
|
|
)
|
|
await kb_helper.initialize()
|
|
self.kb_insts[record.kb_id] = kb_helper
|
|
|
|
async def create_kb(
|
|
self,
|
|
kb_name: str,
|
|
description: str | None = None,
|
|
emoji: str | None = None,
|
|
embedding_provider_id: str | None = None,
|
|
rerank_provider_id: str | None = None,
|
|
chunk_size: int | None = None,
|
|
chunk_overlap: int | None = None,
|
|
top_k_dense: int | None = None,
|
|
top_k_sparse: int | None = None,
|
|
top_m_final: int | None = None,
|
|
) -> KBHelper:
|
|
"""创建新的知识库实例"""
|
|
kb = KnowledgeBase(
|
|
kb_name=kb_name,
|
|
description=description,
|
|
emoji=emoji or "📚",
|
|
embedding_provider_id=embedding_provider_id,
|
|
rerank_provider_id=rerank_provider_id,
|
|
chunk_size=chunk_size if chunk_size is not None else 512,
|
|
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
|
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
|
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
|
top_m_final=top_m_final if top_m_final is not None else 5,
|
|
)
|
|
async with self.kb_db.get_db() as session:
|
|
session.add(kb)
|
|
await session.commit()
|
|
await session.refresh(kb)
|
|
|
|
kb_helper = KBHelper(
|
|
kb_db=self.kb_db,
|
|
kb=kb,
|
|
provider_manager=self.provider_manager,
|
|
kb_root_dir=FILES_PATH,
|
|
chunker=CHUNKER,
|
|
)
|
|
await kb_helper.initialize()
|
|
self.kb_insts[kb.kb_id] = kb_helper
|
|
return kb_helper
|
|
|
|
async def get_kb(self, kb_id: str) -> KBHelper | None:
|
|
"""获取知识库实例"""
|
|
if kb_id in self.kb_insts:
|
|
return self.kb_insts[kb_id]
|
|
|
|
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
|
|
"""通过名称获取知识库实例"""
|
|
for kb_helper in self.kb_insts.values():
|
|
if kb_helper.kb.kb_name == kb_name:
|
|
return kb_helper
|
|
return None
|
|
|
|
async def delete_kb(self, kb_id: str) -> bool:
|
|
"""删除知识库实例"""
|
|
kb_helper = await self.get_kb(kb_id)
|
|
if not kb_helper:
|
|
return False
|
|
|
|
await kb_helper.delete_vec_db()
|
|
async with self.kb_db.get_db() as session:
|
|
await session.delete(kb_helper.kb)
|
|
await session.commit()
|
|
|
|
self.kb_insts.pop(kb_id, None)
|
|
return True
|
|
|
|
async def list_kbs(self) -> list[KnowledgeBase]:
|
|
"""列出所有知识库实例"""
|
|
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
|
|
return kbs
|
|
|
|
async def update_kb(
|
|
self,
|
|
kb_id: str,
|
|
kb_name: str,
|
|
description: str | None = None,
|
|
emoji: str | None = None,
|
|
embedding_provider_id: str | None = None,
|
|
rerank_provider_id: str | None = None,
|
|
chunk_size: int | None = None,
|
|
chunk_overlap: int | None = None,
|
|
top_k_dense: int | None = None,
|
|
top_k_sparse: int | None = None,
|
|
top_m_final: int | None = None,
|
|
) -> KBHelper | None:
|
|
"""更新知识库实例"""
|
|
kb_helper = await self.get_kb(kb_id)
|
|
if not kb_helper:
|
|
return None
|
|
|
|
kb = kb_helper.kb
|
|
if kb_name is not None:
|
|
kb.kb_name = kb_name
|
|
if description is not None:
|
|
kb.description = description
|
|
if emoji is not None:
|
|
kb.emoji = emoji
|
|
if embedding_provider_id is not None:
|
|
kb.embedding_provider_id = embedding_provider_id
|
|
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
|
|
if chunk_size is not None:
|
|
kb.chunk_size = chunk_size
|
|
if chunk_overlap is not None:
|
|
kb.chunk_overlap = chunk_overlap
|
|
if top_k_dense is not None:
|
|
kb.top_k_dense = top_k_dense
|
|
if top_k_sparse is not None:
|
|
kb.top_k_sparse = top_k_sparse
|
|
if top_m_final is not None:
|
|
kb.top_m_final = top_m_final
|
|
async with self.kb_db.get_db() as session:
|
|
session.add(kb)
|
|
await session.commit()
|
|
await session.refresh(kb)
|
|
|
|
return kb_helper
|
|
|
|
async def retrieve(
|
|
self,
|
|
query: str,
|
|
kb_names: list[str],
|
|
top_k_fusion: int = 20,
|
|
top_m_final: int = 5,
|
|
) -> dict | None:
|
|
"""从指定知识库中检索相关内容"""
|
|
kb_ids = []
|
|
kb_id_helper_map = {}
|
|
for kb_name in kb_names:
|
|
if kb_helper := await self.get_kb_by_name(kb_name):
|
|
kb_ids.append(kb_helper.kb.kb_id)
|
|
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
|
|
|
|
if not kb_ids:
|
|
return {}
|
|
|
|
results = await self.retrieval_manager.retrieve(
|
|
query=query,
|
|
kb_ids=kb_ids,
|
|
kb_id_helper_map=kb_id_helper_map,
|
|
top_k_fusion=top_k_fusion,
|
|
top_m_final=top_m_final,
|
|
)
|
|
if not results:
|
|
return None
|
|
|
|
context_text = self._format_context(results)
|
|
|
|
results_dict = [
|
|
{
|
|
"chunk_id": r.chunk_id,
|
|
"doc_id": r.doc_id,
|
|
"kb_id": r.kb_id,
|
|
"kb_name": r.kb_name,
|
|
"doc_name": r.doc_name,
|
|
"chunk_index": r.metadata.get("chunk_index", 0),
|
|
"content": r.content,
|
|
"score": r.score,
|
|
"char_count": r.metadata.get("char_count", 0),
|
|
}
|
|
for r in results
|
|
]
|
|
|
|
return {
|
|
"context_text": context_text,
|
|
"results": results_dict,
|
|
}
|
|
|
|
def _format_context(self, results: list[RetrievalResult]) -> str:
|
|
"""格式化知识上下文
|
|
|
|
Args:
|
|
results: 检索结果列表
|
|
|
|
Returns:
|
|
str: 格式化的上下文文本
|
|
|
|
"""
|
|
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
|
|
|
for i, result in enumerate(results, 1):
|
|
lines.append(f"【知识 {i}】")
|
|
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
|
lines.append(f"内容: {result.content}")
|
|
lines.append(f"相关度: {result.score:.2f}")
|
|
lines.append("")
|
|
|
|
return "\n".join(lines)
|
|
|
|
async def terminate(self):
|
|
"""终止所有知识库实例,关闭数据库连接"""
|
|
for kb_id, kb_helper in self.kb_insts.items():
|
|
try:
|
|
await kb_helper.terminate()
|
|
except Exception as e:
|
|
logger.error(f"关闭知识库 {kb_id} 失败: {e}")
|
|
|
|
self.kb_insts.clear()
|
|
|
|
# 关闭元数据数据库
|
|
if hasattr(self, "kb_db") and self.kb_db:
|
|
try:
|
|
await self.kb_db.close()
|
|
except Exception as e:
|
|
logger.error(f"关闭知识库元数据数据库失败: {e}")
|
|
|
|
async def upload_from_url(
|
|
self,
|
|
kb_id: str,
|
|
url: str,
|
|
chunk_size: int = 512,
|
|
chunk_overlap: int = 50,
|
|
batch_size: int = 32,
|
|
tasks_limit: int = 3,
|
|
max_retries: int = 3,
|
|
progress_callback=None,
|
|
) -> KBDocument:
|
|
"""从 URL 上传文档到指定的知识库
|
|
|
|
Args:
|
|
kb_id: 知识库 ID
|
|
url: 要提取内容的网页 URL
|
|
chunk_size: 文本块大小
|
|
chunk_overlap: 文本块重叠大小
|
|
batch_size: 批处理大小
|
|
tasks_limit: 并发任务限制
|
|
max_retries: 最大重试次数
|
|
progress_callback: 进度回调函数
|
|
|
|
Returns:
|
|
KBDocument: 上传的文档对象
|
|
|
|
Raises:
|
|
ValueError: 如果知识库不存在或 URL 为空
|
|
IOError: 如果网络请求失败
|
|
"""
|
|
kb_helper = await self.get_kb(kb_id)
|
|
if not kb_helper:
|
|
raise ValueError(f"Knowledge base with id {kb_id} not found.")
|
|
|
|
return await kb_helper.upload_from_url(
|
|
url=url,
|
|
chunk_size=chunk_size,
|
|
chunk_overlap=chunk_overlap,
|
|
batch_size=batch_size,
|
|
tasks_limit=tasks_limit,
|
|
max_retries=max_retries,
|
|
progress_callback=progress_callback,
|
|
)
|