diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ef063ae6..d23364b1 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -134,27 +134,9 @@ DEFAULT_CONFIG = { "persona": [], # deprecated "timezone": "Asia/Shanghai", "callback_api_base": "", - "default_kb_collection": "", # 默认知识库名称 + "default_kb_collection": "", # 默认知识库名称, 已经过时 "plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件 - "knowledge_base": { - "enabled": False, # 默认禁用,用户需要主动启用 - "embedding_provider_id": "", # 嵌入模型提供商 ID (为空时自动选择第一个) - "rerank_provider_id": "", # 重排序模型提供商 ID (为空时自动选择第一个) - "storage": { - "files_path": "data/knowledge_base", # 文件存储路径 - "vector_db_path": "data/knowledge_base/vectors", # 向量数据库路径 - }, - "chunking": { - "chunk_size": 512, # 文档块大小(字符数) - "chunk_overlap": 50, # 文档块重叠大小(字符数) - }, - "retrieval": { - "top_k_dense": 50, # 密集检索返回结果数 - "top_k_sparse": 50, # 稀疏检索返回结果数 - "top_m_final": 5, # 最终融合后返回的结果数 - "enable_rerank": True, # 是否启用重排序 - }, - }, + "kb_names": [], # 默认知识库名称列表 } diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index a07c4d14..a485734d 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -34,7 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_map -from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager +from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager class AstrBotCoreLifecycle: @@ -112,9 +112,7 @@ class AstrBotCoreLifecycle: self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) # 初始化知识库管理器 - self.kb_manager = KnowledgeBaseManager( - self.astrbot_config, self.provider_manager - ) + self.kb_manager = KnowledgeBaseManager(self.provider_manager) # 初始化提供给插件的上下文 self.star_context = Context( @@ -141,10 +139,6 @@ class AstrBotCoreLifecycle: await self.kb_manager.initialize() - # 注册知识库会话生命周期钩子(零侵入级联清理) - if self.kb_manager.is_initialized: - self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager) - # 初始化消息事件流水线调度器 self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler() diff --git a/astrbot/core/knowledge_base/__init__.py b/astrbot/core/knowledge_base/__init__.py deleted file mode 100644 index df403436..00000000 --- a/astrbot/core/knowledge_base/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -知识库管理模块 - -提供文档上传、解析、分块、向量化、检索等功能 -""" - -from astrbot.core.knowledge_base.models import ( - KBChunk, - KBDocument, - KBMedia, - KBSessionConfig, - KnowledgeBase, -) - -# 注意: 以下导入在对应模块实现后取消注释 -from .database import KBDatabase -from .manager import KBManager -from .manager_ops import KBManagerOps -from .session_config_db import SessionConfigDB - -# from .injector import KnowledgeBaseInjector - -__all__ = [ - "KnowledgeBase", - "KBDocument", - "KBChunk", - "KBMedia", - "KBSessionConfig", - "KBDatabase", - "SessionConfigDB", - "KBManager", - "KBManagerOps", - # "KnowledgeBaseInjector", -] diff --git a/astrbot/core/knowledge_base/database.py b/astrbot/core/knowledge_base/database.py deleted file mode 100644 index d48c9e52..00000000 --- a/astrbot/core/knowledge_base/database.py +++ /dev/null @@ -1,183 +0,0 @@ -"""知识库数据库操作类 - -该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。 - -注意: -- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db) -- 会话配置也存储在此数据库中,会话ID来源于主数据库 -""" - -from typing import Optional - -from sqlalchemy import func, select - -from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase -from astrbot.core.knowledge_base.models import ( - KBChunk, - KBDocument, - KBMedia, - KnowledgeBase, -) - - -class KBDatabase: - """知识库数据库操作类 - - 职责: - - 封装知识库、文档、块、多媒体和会话配置的数据库查询操作 - - 统一异常处理 - - 注意: - - 该类操作独立的知识库数据库 (kb.db) - - 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库 - """ - - def __init__(self, kb_db: KBSQLiteDatabase): - """初始化知识库数据库操作类 - - Args: - kb_db: 知识库独立数据库实例,而非主数据库 - """ - self.db = kb_db - - # ===== 知识库查询 ===== - - async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]: - """根据 ID 获取知识库""" - async with self.db.get_db() as session: - stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]: - """根据名称获取知识库""" - async with self.db.get_db() as session: - stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]: - """列出所有知识库""" - async with self.db.get_db() as session: - stmt = ( - select(KnowledgeBase) - .offset(offset) - .limit(limit) - .order_by(KnowledgeBase.created_at.desc()) - ) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def count_kbs(self) -> int: - """统计知识库数量""" - async with self.db.get_db() as session: - stmt = select(func.count(KnowledgeBase.id)) - result = await session.execute(stmt) - return result.scalar() or 0 - - # ===== 文档查询 ===== - - async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]: - """根据 ID 获取文档""" - async with self.db.get_db() as session: - stmt = select(KBDocument).where(KBDocument.doc_id == doc_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def list_documents_by_kb( - self, kb_id: str, offset: int = 0, limit: int = 100 - ) -> list[KBDocument]: - """列出知识库的所有文档""" - async with self.db.get_db() as session: - stmt = ( - select(KBDocument) - .where(KBDocument.kb_id == kb_id) - .offset(offset) - .limit(limit) - .order_by(KBDocument.created_at.desc()) - ) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def count_documents_by_kb(self, kb_id: str) -> int: - """统计知识库的文档数量""" - async with self.db.get_db() as session: - stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id) - result = await session.execute(stmt) - return result.scalar() or 0 - - # ===== 块查询 ===== - - async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]: - """根据 ID 获取块""" - async with self.db.get_db() as session: - stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]: - """根据知识库 ID 列表获取所有块""" - async with self.db.get_db() as session: - stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids)) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]: - """根据向量文档 ID 获取块""" - async with self.db.get_db() as session: - stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]: - """获取块及其关联的文档和知识库元数据""" - async with self.db.get_db() as session: - stmt = ( - select(KBChunk, KBDocument, KnowledgeBase) - .join(KBDocument, KBChunk.doc_id == KBDocument.doc_id) - .join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id) - .where(KBChunk.chunk_id == chunk_id) - ) - result = await session.execute(stmt) - row = result.first() - - if not row: - return None - - chunk, doc, kb = row - return { - "chunk": chunk, - "document": doc, - "knowledge_base": kb, - } - - async def list_chunks_by_doc( - self, doc_id: str, offset: int = 0, limit: int = 100 - ) -> list[KBChunk]: - """列出文档的所有块""" - async with self.db.get_db() as session: - stmt = ( - select(KBChunk) - .where(KBChunk.doc_id == doc_id) - .offset(offset) - .limit(limit) - .order_by(KBChunk.chunk_index) - ) - result = await session.execute(stmt) - return list(result.scalars().all()) - - # ===== 多媒体查询 ===== - - async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]: - """列出文档的所有多媒体资源""" - async with self.db.get_db() as session: - stmt = select(KBMedia).where(KBMedia.doc_id == doc_id) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]: - """根据 ID 获取多媒体资源""" - async with self.db.get_db() as session: - stmt = select(KBMedia).where(KBMedia.media_id == media_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() diff --git a/astrbot/core/knowledge_base/injector.py b/astrbot/core/knowledge_base/injector.py deleted file mode 100644 index d2a81391..00000000 --- a/astrbot/core/knowledge_base/injector.py +++ /dev/null @@ -1,112 +0,0 @@ -"""知识库上下文注入器 - -负责检索相关知识并格式化为 LLM 可用的上下文文本 -""" - -from typing import List, Optional - -from astrbot.core.knowledge_base.database import KBDatabase -from astrbot.core.knowledge_base.retrieval.manager import ( - RetrievalManager, - RetrievalResult, -) -from .vec_db_factory import VecDBFactory - - -class KnowledgeBaseInjector: - """知识库上下文注入器 - - 职责: - - 检索相关知识 - - 格式化为上下文文本 - - 注入到 LLM Prompt - """ - - def __init__( - self, - kb_db: KBDatabase, - vec_db_factory: VecDBFactory, - retrieval_manager: RetrievalManager, - ): - """初始化知识库上下文注入器 - - Args: - kb_db: 知识库数据库实例 - retrieval_manager: 检索管理器实例 - """ - self.kb_db = kb_db - self.vec_db_factory = vec_db_factory - self.retrieval_manager = retrieval_manager - - async def retrieve_and_inject( - self, - kb_ids: list[str], - query: str, - top_k: int = 5, - ) -> Optional[dict]: - """检索并注入知识库上下文 - - Args: - query: 用户查询 - top_k: 返回结果数量 - - Returns: - Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None - { - "context_text": str, # 格式化的上下文文本 - "results": List[dict], # 原始检索结果列表 - } - """ - # 2. 检索知识 - results = await self.retrieval_manager.retrieve( - vec_db_factory=self.vec_db_factory, - query=query, - kb_ids=kb_ids, - top_m_final=top_k, - ) - - if not results: - return None - - # 3. 格式化上下文 - context_text = self._format_context(results) - - # 4. 转换结果为字典格式 - results_dict = [ - { - "chunk_id": r.chunk_id, - "doc_id": r.doc_id, - "kb_id": r.kb_id, - "kb_name": r.kb_name, - "doc_name": r.doc_name, - "chunk_index": r.metadata.get("chunk_index", 0), - "content": r.content, - "score": r.score, - } - for r in results - ] - - return { - "context_text": context_text, - "results": results_dict, - } - - def _format_context(self, results: List[RetrievalResult]) -> str: - """格式化知识上下文 - - Args: - results: 检索结果列表 - - Returns: - str: 格式化的上下文文本 - """ - lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"] - - for i, result in enumerate(results, 1): - lines.append(f"【知识 {i}】") - lines.append(f"来源: {result.kb_name} / {result.doc_name}") - lines.append(f"内容: {result.content}") - lines.append(f"相关度: {result.score:.2f}") - lines.append("") - - return "\n".join(lines) diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py new file mode 100644 index 00000000..9724818d --- /dev/null +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -0,0 +1,383 @@ +from contextlib import asynccontextmanager +from pathlib import Path + +from sqlmodel import SQLModel, col, desc +from sqlalchemy import text, func, select, update +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from astrbot.core import logger +from astrbot.core.knowledge_base.models import ( + KBChunk, + KBDocument, + KBMedia, + KnowledgeBase, +) + +from typing import Optional + + +class KBSQLiteDatabase: + def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: + """初始化知识库数据库 + + Args: + db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db + """ + self.db_path = db_path + self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" + self.inited = False + + # 确保目录存在 + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + + # 创建异步引擎 + self.engine = create_async_engine( + self.DATABASE_URL, + echo=False, + pool_pre_ping=True, + pool_recycle=3600, + ) + + # 创建会话工厂 + self.async_session = async_sessionmaker( + self.engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + @asynccontextmanager + async def get_db(self): + """获取数据库会话 + + 用法: + async with kb_db.get_db() as session: + # 执行数据库操作 + result = await session.execute(stmt) + """ + async with self.async_session() as session: + yield session + + async def initialize(self) -> None: + """初始化数据库,创建表并配置 SQLite 参数""" + async with self.engine.begin() as conn: + # 创建所有知识库相关表 + await conn.run_sync(SQLModel.metadata.create_all) + + # 配置 SQLite 性能优化参数 + await conn.execute(text("PRAGMA journal_mode=WAL")) + await conn.execute(text("PRAGMA synchronous=NORMAL")) + await conn.execute(text("PRAGMA cache_size=20000")) + await conn.execute(text("PRAGMA temp_store=MEMORY")) + await conn.execute(text("PRAGMA mmap_size=134217728")) + await conn.execute(text("PRAGMA optimize")) + await conn.commit() + + self.inited = True + + async def migrate_to_v1(self) -> None: + """执行知识库数据库 v1 迁移 + + 创建所有必要的索引以优化查询性能 + """ + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + # 创建知识库表索引 + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_kb_kb_id " + "ON knowledge_bases(kb_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_kb_name " + "ON knowledge_bases(kb_name)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_kb_created_at " + "ON knowledge_bases(created_at)" + ) + ) + + # 创建文档表索引 + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_doc_id " + "ON kb_documents(doc_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_kb_id " + "ON kb_documents(kb_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_name " + "ON kb_documents(doc_name)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_type " + "ON kb_documents(file_type)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_doc_created_at " + "ON kb_documents(created_at)" + ) + ) + + # 创建块表索引 + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id " + "ON kb_chunks(chunk_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_chunk_doc_id " + "ON kb_chunks(doc_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id " + "ON kb_chunks(vec_doc_id)" + ) + ) + + # 创建多媒体表索引 + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_media_media_id " + "ON kb_media(media_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_media_doc_id " + "ON kb_media(doc_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_media_type " + "ON kb_media(media_type)" + ) + ) + + # 创建会话配置表索引 + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_session_config_scope_id " + "ON kb_session_config(scope_id)" + ) + ) + await session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_session_config_scope " + "ON kb_session_config(scope)" + ) + ) + + await session.commit() + + async def close(self) -> None: + """关闭数据库连接""" + await self.engine.dispose() + logger.info(f"知识库数据库已关闭: {self.db_path}") + + async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]: + """根据 ID 获取知识库""" + async with self.get_db() as session: + stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]: + """根据名称获取知识库""" + async with self.get_db() as session: + stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]: + """列出所有知识库""" + async with self.get_db() as session: + stmt = ( + select(KnowledgeBase) + .offset(offset) + .limit(limit) + .order_by(desc(KnowledgeBase.created_at)) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def count_kbs(self) -> int: + """统计知识库数量""" + async with self.get_db() as session: + stmt = select(func.count(col(KnowledgeBase.id))) + result = await session.execute(stmt) + return result.scalar() or 0 + + # ===== 文档查询 ===== + + async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]: + """根据 ID 获取文档""" + async with self.get_db() as session: + stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def list_documents_by_kb( + self, kb_id: str, offset: int = 0, limit: int = 100 + ) -> list[KBDocument]: + """列出知识库的所有文档""" + async with self.get_db() as session: + stmt = ( + select(KBDocument) + .where(col(KBDocument.kb_id) == kb_id) + .offset(offset) + .limit(limit) + .order_by(desc(KBDocument.created_at)) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def count_documents_by_kb(self, kb_id: str) -> int: + """统计知识库的文档数量""" + async with self.get_db() as session: + stmt = select(func.count(col(KBDocument.id))).where( + col(KBDocument.kb_id) == kb_id + ) + result = await session.execute(stmt) + return result.scalar() or 0 + + # ===== 块查询 ===== + + async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]: + """根据 ID 获取块""" + async with self.get_db() as session: + stmt = select(KBChunk).where(col(KBChunk.chunk_id) == chunk_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]: + """根据知识库 ID 列表获取所有块""" + async with self.get_db() as session: + stmt = select(KBChunk).where(col(KBChunk.kb_id).in_(kb_ids)) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]: + """根据向量文档 ID 获取块""" + async with self.get_db() as session: + stmt = select(KBChunk).where(col(KBChunk.vec_doc_id) == vec_doc_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def get_chunks_by_doc_id( + self, doc_id: str, offset: int = 0, limit: int = 100 + ) -> list[KBChunk]: + """根据文档 ID 获取所有块""" + async with self.get_db() as session: + stmt = ( + select(KBChunk) + .where(col(KBChunk.doc_id) == doc_id) + .offset(offset) + .limit(limit) + .order_by(col(KBChunk.chunk_index)) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]: + """获取块及其关联的文档和知识库元数据""" + async with self.get_db() as session: + stmt = ( + select(KBChunk, KBDocument, KnowledgeBase) + .join(KBDocument, col(KBChunk.doc_id) == col(KBDocument.doc_id)) + .join(KnowledgeBase, col(KBChunk.kb_id) == col(KnowledgeBase.kb_id)) + .where(col(KBChunk.chunk_id) == chunk_id) + ) + result = await session.execute(stmt) + row = result.first() + + if not row: + return None + + chunk, doc, kb = row + return { + "chunk": chunk, + "document": doc, + "knowledge_base": kb, + } + + async def list_chunks_by_doc( + self, doc_id: str, offset: int = 0, limit: int = 100 + ) -> list[KBChunk]: + """列出文档的所有块""" + async with self.get_db() as session: + stmt = ( + select(KBChunk) + .where(col(KBChunk.doc_id) == doc_id) + .offset(offset) + .limit(limit) + .order_by(col(KBChunk.chunk_index)) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + # ===== 多媒体查询 ===== + + async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]: + """列出文档的所有多媒体资源""" + async with self.get_db() as session: + stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]: + """根据 ID 获取多媒体资源""" + async with self.get_db() as session: + stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def update_kb_stats(self, kb_id: str): + """更新知识库统计信息""" + async with self.get_db() as session: + async with session.begin(): + update_stmt = ( + update(KnowledgeBase) + .where(col(KnowledgeBase.kb_id) == kb_id) + .values( + doc_count=select(func.count(col(KBDocument.id))) + .where(col(KBDocument.kb_id) == kb_id) + .scalar_subquery(), + chunk_count=select(func.count(col(KBChunk.id))) + .where(col(KBChunk.kb_id) == kb_id) + .scalar_subquery(), + ) + ) + + await session.execute(update_stmt) + await session.commit() diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py new file mode 100644 index 00000000..59277e23 --- /dev/null +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -0,0 +1,248 @@ +import uuid +import aiofiles +from pathlib import Path +from .models import KnowledgeBase, KBDocument, KBChunk, KBMedia +from .kb_db_sqlite import KBSQLiteDatabase +from astrbot.core.db.vec_db.base import BaseVecDB +from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB +from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider +from astrbot.core.provider.manager import ProviderManager +from .parsers.base import BaseParser +from .chunking.base import BaseChunker +from astrbot.core import logger + + +class KBHelper: + vec_db: BaseVecDB + + def __init__( + self, + kb_db: KBSQLiteDatabase, + kb: KnowledgeBase, + provider_manager: ProviderManager, + kb_root_dir: str, + chunker: BaseChunker, + parsers: dict[str, BaseParser], + ): + self.kb_db = kb_db + self.kb = kb + self.prov_mgr = provider_manager + self.kb_root_dir = kb_root_dir + self.parsers = parsers + self.chunker = chunker + + self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id + self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id + self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id + + self.kb_medias_dir.mkdir(parents=True, exist_ok=True) + self.kb_files_dir.mkdir(parents=True, exist_ok=True) + + async def initialize(self): + await self._ensure_vec_db() + + async def get_ep(self) -> EmbeddingProvider: + if not self.kb.embedding_provider_id: + raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") + ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id( + self.kb.embedding_provider_id + ) # type: ignore + if not ep: + raise ValueError( + f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider" + ) + return ep + + async def get_rp(self) -> RerankProvider | None: + if not self.kb.rerank_provider_id: + return None + rp: RerankProvider = await self.prov_mgr.get_provider_by_id( + self.kb.rerank_provider_id + ) # type: ignore + if not rp: + raise ValueError( + f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider" + ) + return rp + + async def _ensure_vec_db(self) -> FaissVecDB: + if not self.kb.embedding_provider_id: + raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") + + ep = await self.get_ep() + rp = await self.get_rp() + + vec_db = FaissVecDB( + doc_store_path=str(self.kb_dir / "doc.db"), + index_store_path=str(self.kb_dir / "index.faiss"), + embedding_provider=ep, + rerank_provider=rp, + ) + await vec_db.initialize() + self.vec_db = vec_db + return vec_db + + async def delete_vec_db(self): + await self.terminate() + if self.kb_dir.exists(): + for item in self.kb_dir.iterdir(): + if item.is_file(): + item.unlink() + self.kb_dir.rmdir() + + async def terminate(self): + if self.vec_db: + await self.vec_db.close() + + async def upload_document( + self, + file_name: str, + file_content: bytes, + file_type: str, + ) -> KBDocument: + """上传并处理文档(带原子性保证和失败清理) + + 流程: + 1. 保存原始文件 + 2. 解析文档内容 + 3. 提取多媒体资源 (TODO) + 4. 分块处理 + 5. 生成向量并存储 + 6. 保存元数据(事务) + 7. 更新统计 + """ + await self._ensure_vec_db() + doc_id = str(uuid.uuid4()) + media_paths: list[Path] = [] + vec_doc_ids = [] + + file_path = self.kb_files_dir / f"{doc_id}.{file_type}" + async with aiofiles.open(file_path, "wb") as f: + await f.write(file_content) + + try: + parser = self.parsers.get(file_type) + if not parser: + raise ValueError(f"不支持的文件类型: {file_type}") + parse_result = await parser.parse(file_content, file_name) + text_content = parse_result.text + media_items = parse_result.media + + # 保存媒体文件 + saved_media = [] + for media_item in media_items: + media = await self._save_media( + doc_id=doc_id, + media_type=media_item.media_type, + file_name=media_item.file_name, + content=media_item.content, + mime_type=media_item.mime_type, + ) + saved_media.append(media) + media_paths.append(Path(media.file_path)) + + # 分块并生成向量 + saved_chunks = [] + chunks_text = await self.chunker.chunk(text_content) + for idx, chunk_text in enumerate(chunks_text): + vec_doc_id = await self.vec_db.insert( + content=chunk_text, + metadata={ + "kb_id": self.kb.kb_id, + "doc_id": doc_id, + "chunk_index": idx, + }, + ) + vec_doc_ids.append(str(vec_doc_id)) + + chunk = KBChunk( + doc_id=doc_id, + kb_id=self.kb.kb_id, + chunk_index=idx, + content=chunk_text, + char_count=len(chunk_text), + vec_doc_id=str(vec_doc_id), + ) + saved_chunks.append(chunk) + + # 保存文档和块的元数据 + doc = KBDocument( + doc_id=doc_id, + kb_id=self.kb.kb_id, + doc_name=file_name, + file_type=file_type, + file_size=len(file_content), + file_path=str(file_path), + chunk_count=len(saved_chunks), + media_count=0, + ) + async with self.kb_db.get_db() as session: + async with session.begin(): + session.add(doc) + for chunk in saved_chunks: + session.add(chunk) + for media in saved_media: + session.add(media) + await session.commit() + + await session.refresh(doc) + + await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id) + + return doc + except Exception as e: + logger.error(f"上传文档失败: {e}") + if file_path.exists(): + file_path.unlink() + + for media_path in media_paths: + try: + if media_path.exists(): + media_path.unlink() + except Exception as me: + logger.warning(f"清理多媒体文件失败 {media_path}: {me}") + + raise e + + async def list_documents( + self, offset: int = 0, limit: int = 100 + ) -> list[KBDocument]: + """列出知识库的所有文档""" + docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit) + return docs + + async def get_document(self, doc_id: str) -> KBDocument | None: + """获取单个文档""" + doc = await self.kb_db.get_document_by_id(doc_id) + return doc + + async def _save_media( + self, + doc_id: str, + media_type: str, + file_name: str, + content: bytes, + mime_type: str, + ) -> KBMedia: + """保存多媒体资源""" + media_id = str(uuid.uuid4()) + ext = Path(file_name).suffix + + # 保存文件 + file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}" + file_path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(file_path, "wb") as f: + await f.write(content) + + media = KBMedia( + media_id=media_id, + doc_id=doc_id, + kb_id=self.kb.kb_id, + media_type=media_type, + file_name=file_name, + file_path=str(file_path), + file_size=len(content), + mime_type=mime_type, + ) + + return media diff --git a/astrbot/core/knowledge_base/kb_manager_lifecycle.py b/astrbot/core/knowledge_base/kb_manager_lifecycle.py deleted file mode 100644 index c0b709a5..00000000 --- a/astrbot/core/knowledge_base/kb_manager_lifecycle.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -知识库管理器 -负责知识库模块的初始化、配置和资源管理 - -架构说明: -- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db) -- 会话配置存储在主数据库 (data/astrbot.db) 以便于会话关联 -""" - -from pathlib import Path -from astrbot.core import logger -from astrbot.core.provider.manager import ProviderManager -from .injector import KnowledgeBaseInjector -from .retrieval.manager import RetrievalManager -from .retrieval.sparse_retriever import SparseRetriever -from .retrieval.rank_fusion import RankFusion -from .kb_sqlite import KBSQLiteDatabase -from .database import KBDatabase -from .vec_db_factory import VecDBFactory -from .manager import KBManager -from .parsers.text_parser import TextParser -from .parsers.pdf_parser import PDFParser -from .chunking.fixed_size import FixedSizeChunker - - -class KnowledgeBaseManager: - """知识库管理器 - - 职责: - - 知识库模块的初始化 - - Embedding Provider 和 Rerank Provider 的选择 - - 各个子组件的协调管理 - - 注册会话删除回调,实现级联清理 - - 架构说明: - - 知识库数据存储在独立数据库 (kb.db) - - 会话配置存储在独立数据库 (kb.db),会话ID来自主数据库 - - 通过回调机制实现与主数据库的生命周期同步 - """ - - kb_db: KBSQLiteDatabase - vec_db_factory: VecDBFactory - kb_database: KBDatabase - kb_manager: KBManager - retrieval_manager: RetrievalManager - kb_injector: KnowledgeBaseInjector - - def __init__( - self, - config: dict, - provider_manager: ProviderManager, - ): - """初始化知识库管理器 - - Args: - config: 配置字典 - provider_manager: Provider 管理器 - """ - self.config = config.get("knowledge_base", {}) - self.provider_manager = provider_manager - self._initialized = False - self._session_deleted_callback_registered = False - - async def initialize(self): - """初始化知识库模块""" - if not self.config.get("enabled", False): - logger.info("知识库功能未启用") - return - - try: - logger.info("正在初始化知识库模块...") - - # 初始化数据库 - await self._init_kb_database() - - # 初始化向量数据库工厂 - await self._init_vector_db_factory() - - # 初始化解析器和分块器 - parsers = { - "txt": TextParser(), - "md": TextParser(), - "markdown": TextParser(), - "pdf": PDFParser(), - } - chunking_config = self.config.get("chunking", {}) - chunker = FixedSizeChunker( - chunk_size=chunking_config.get("chunk_size", 512), - chunk_overlap=chunking_config.get("chunk_overlap", 50), - ) - - # 初始化知识库管理器 - files_path = self.config.get("storage", {}).get( - "files_path", "data/knowledge_base" - ) - self.kb_manager = KBManager( - db=self.kb_db, - vec_db_factory=self.vec_db_factory, - storage_path=files_path, - parsers=parsers, - chunker=chunker, - provider_manager=self.provider_manager, - ) - - # 初始化检索管理器 - sparse_retriever = SparseRetriever(self.kb_database) - rank_fusion = RankFusion(self.kb_database) - self.retrieval_manager = RetrievalManager( - vec_db_factory=self.vec_db_factory, - sparse_retriever=sparse_retriever, - rank_fusion=rank_fusion, - kb_db=self.kb_database, - ) - - # 初始化上下文注入器 - self.kb_injector = KnowledgeBaseInjector( - kb_db=self.kb_database, - vec_db_factory=self.vec_db_factory, - retrieval_manager=self.retrieval_manager, - ) - - self._initialized = True - logger.info("知识库模块初始化完成") - - except ImportError as e: - logger.error(f"知识库模块导入失败: {e}") - logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25") - except Exception as e: - logger.error(f"知识库模块初始化失败: {e}") - import traceback - - logger.error(traceback.format_exc()) - - async def _init_kb_database(self): - """初始化知识库独立数据库""" - db_path = self.config.get("storage", {}).get( - "kb_db_path", "data/knowledge_base/kb.db" - ) - Path(db_path).parent.mkdir(parents=True, exist_ok=True) - - self.kb_db = KBSQLiteDatabase(db_path) - await self.kb_db.initialize() - await self.kb_db.migrate_to_v1() - self.kb_database = KBDatabase(self.kb_db) - logger.info(f"KnowledgeBase database initialized: {db_path}") - - async def _init_vector_db_factory(self): - """初始化向量数据库工厂""" - storage_path = self.config.get("storage", {}).get( - "vector_db_path", "data/knowledge_base/vectors" - ) - Path(storage_path).mkdir(parents=True, exist_ok=True) - self.vec_db_factory = VecDBFactory(storage_base_path=storage_path) - - @property - def is_initialized(self) -> bool: - """检查是否已初始化""" - return self._initialized - - def get_kb_manager(self): - """获取知识库管理器""" - return self.kb_manager if self._initialized else None - - def get_kb_injector(self): - """获取知识库上下文注入器""" - return self.kb_injector if self._initialized else None - - async def reinitialize(self): - """重新初始化知识库模块 - - 用于在运行时动态初始化知识库模块(例如用户添加了 embedding provider 后) - """ - if self._initialized: - logger.info("知识库模块已初始化,将重新初始化") - await self.terminate() - - await self.initialize() - return self._initialized - - async def terminate(self): - """终止知识库模块,清理资源""" - if not self._initialized: - return - - logger.info("正在终止知识库模块...") - - # 关闭向量数据库工厂(关闭所有向量数据库实例) - if self.vec_db_factory: - try: - await self.vec_db_factory.close_all() - logger.debug("向量数据库工厂已关闭") - except Exception as e: - logger.warning(f"关闭向量数据库工厂时出错: {e}") - - # 关闭知识库独立数据库连接 - if self.kb_db: - try: - await self.kb_db.close() - logger.debug("知识库数据库已关闭") - except Exception as e: - logger.warning(f"关闭知识库数据库时出错: {e}") - - self._initialized = False - - logger.info("知识库模块已终止") diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py new file mode 100644 index 00000000..2d8ff872 --- /dev/null +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -0,0 +1,275 @@ +import traceback +from pathlib import Path +from astrbot.core import logger +from astrbot.core.provider.manager import ProviderManager + +from .retrieval.manager import RetrievalManager, RetrievalResult +from .retrieval.sparse_retriever import SparseRetriever +from .retrieval.rank_fusion import RankFusion +from .kb_db_sqlite import KBSQLiteDatabase + +from .parsers.text_parser import TextParser +from .parsers.pdf_parser import PDFParser +from .chunking.fixed_size import FixedSizeChunker +from .kb_helper import KBHelper + +from .models import KnowledgeBase + + +FILES_PATH = "data/knowledge_base" +DB_PATH = Path(FILES_PATH) / "kb.db" +"""Knowledge Base storage root directory""" +PARSERS = { + "txt": TextParser(), + "md": TextParser(), + "markdown": TextParser(), + "pdf": PDFParser(), +} +CHUNKER = FixedSizeChunker() + + +class KnowledgeBaseManager: + kb_db: KBSQLiteDatabase + retrieval_manager: RetrievalManager + + def __init__( + self, + provider_manager: ProviderManager, + ): + Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) + self.provider_manager = provider_manager + self._session_deleted_callback_registered = False + + self.kb_insts: dict[str, KBHelper] = {} + + async def initialize(self): + """初始化知识库模块""" + try: + logger.info("正在初始化知识库模块...") + + # 初始化数据库 + await self._init_kb_database() + + # 初始化检索管理器 + sparse_retriever = SparseRetriever(self.kb_db) + rank_fusion = RankFusion(self.kb_db) + self.retrieval_manager = RetrievalManager( + sparse_retriever=sparse_retriever, + rank_fusion=rank_fusion, + kb_db=self.kb_db, + ) + await self.load_kbs() + + except ImportError as e: + logger.error(f"知识库模块导入失败: {e}") + logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25") + except Exception as e: + logger.error(f"知识库模块初始化失败: {e}") + logger.error(traceback.format_exc()) + + async def _init_kb_database(self): + self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) + await self.kb_db.initialize() + await self.kb_db.migrate_to_v1() + logger.info(f"KnowledgeBase database initialized: {DB_PATH}") + + async def load_kbs(self): + """加载所有知识库实例""" + kb_records = await self.kb_db.list_kbs() + for record in kb_records: + kb_helper = KBHelper( + kb_db=self.kb_db, + kb=record, + provider_manager=self.provider_manager, + kb_root_dir=FILES_PATH, + chunker=CHUNKER, + parsers=PARSERS, + ) + await kb_helper.initialize() + self.kb_insts[record.kb_id] = kb_helper + + async def create_kb( + self, + kb_name: str, + description: str | None = None, + emoji: str | None = None, + embedding_provider_id: str | None = None, + rerank_provider_id: str | None = None, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + top_k_dense: int | None = None, + top_k_sparse: int | None = None, + top_m_final: int | None = None, + ) -> KBHelper: + """创建新的知识库实例""" + kb = KnowledgeBase( + kb_name=kb_name, + description=description, + emoji=emoji or "📚", + embedding_provider_id=embedding_provider_id, + rerank_provider_id=rerank_provider_id, + chunk_size=chunk_size if chunk_size is not None else 512, + chunk_overlap=chunk_overlap if chunk_overlap is not None else 50, + top_k_dense=top_k_dense if top_k_dense is not None else 50, + top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, + top_m_final=top_m_final if top_m_final is not None else 5, + ) + async with self.kb_db.get_db() as session: + session.add(kb) + await session.commit() + await session.refresh(kb) + + kb_helper = KBHelper( + kb_db=self.kb_db, + kb=kb, + provider_manager=self.provider_manager, + kb_root_dir=FILES_PATH, + chunker=CHUNKER, + parsers=PARSERS, + ) + await kb_helper.initialize() + self.kb_insts[kb.kb_id] = kb_helper + return kb_helper + + async def get_kb(self, kb_id: str) -> KBHelper | None: + """获取知识库实例""" + if kb_id in self.kb_insts: + return self.kb_insts[kb_id] + + async def get_kb_by_name(self, kb_name: str) -> KBHelper | None: + """通过名称获取知识库实例""" + for kb_helper in self.kb_insts.values(): + if kb_helper.kb.kb_name == kb_name: + return kb_helper + return None + + async def delete_kb(self, kb_id: str) -> bool: + """删除知识库实例""" + kb_helper = await self.get_kb(kb_id) + if not kb_helper: + return False + + await kb_helper.delete_vec_db() + async with self.kb_db.get_db() as session: + await session.delete(kb_helper.kb) + await session.commit() + + self.kb_insts.pop(kb_id, None) + return True + + async def list_kbs(self) -> list[KnowledgeBase]: + """列出所有知识库实例""" + kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()] + return kbs + + async def update_kb( + self, + kb_id: str, + kb_name: str, + description: str | None = None, + emoji: str | None = None, + embedding_provider_id: str | None = None, + rerank_provider_id: str | None = None, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + top_k_dense: int | None = None, + top_k_sparse: int | None = None, + top_m_final: int | None = None, + ) -> KBHelper | None: + """更新知识库实例""" + kb_helper = await self.get_kb(kb_id) + if not kb_helper: + return None + + kb = kb_helper.kb + if kb_name is not None: + kb.kb_name = kb_name + if description is not None: + kb.description = description + if emoji is not None: + kb.emoji = emoji + if embedding_provider_id is not None: + kb.embedding_provider_id = embedding_provider_id + if rerank_provider_id is not None: + kb.rerank_provider_id = rerank_provider_id + if chunk_size is not None: + kb.chunk_size = chunk_size + if chunk_overlap is not None: + kb.chunk_overlap = chunk_overlap + if top_k_dense is not None: + kb.top_k_dense = top_k_dense + if top_k_sparse is not None: + kb.top_k_sparse = top_k_sparse + if top_m_final is not None: + kb.top_m_final = top_m_final + async with self.kb_db.get_db() as session: + session.add(kb) + await session.commit() + await session.refresh(kb) + + async def retrieve( + self, + query: str, + kb_names: list[str], + top_m_final: int = 5, + ) -> dict | None: + """从指定知识库中检索相关内容""" + kb_ids = [] + kb_id_helper_map = {} + for kb_name in kb_names: + if kb_helper := await self.get_kb_by_name(kb_name): + kb_ids.append(kb_helper.kb.kb_id) + kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper + + if not kb_ids: + return {} + + results = await self.retrieval_manager.retrieve( + query=query, + kb_ids=kb_ids, + kb_id_helper_map=kb_id_helper_map, + top_m_final=top_m_final, + ) + if not results: + return None + + context_text = self._format_context(results) + + results_dict = [ + { + "chunk_id": r.chunk_id, + "doc_id": r.doc_id, + "kb_id": r.kb_id, + "kb_name": r.kb_name, + "doc_name": r.doc_name, + "chunk_index": r.metadata.get("chunk_index", 0), + "content": r.content, + "score": r.score, + } + for r in results + ] + + return { + "context_text": context_text, + "results": results_dict, + } + + def _format_context(self, results: list[RetrievalResult]) -> str: + """格式化知识上下文 + + Args: + results: 检索结果列表 + + Returns: + str: 格式化的上下文文本 + """ + lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"] + + for i, result in enumerate(results, 1): + lines.append(f"【知识 {i}】") + lines.append(f"来源: {result.kb_name} / {result.doc_name}") + lines.append(f"内容: {result.content}") + lines.append(f"相关度: {result.score:.2f}") + lines.append("") + + return "\n".join(lines) diff --git a/astrbot/core/knowledge_base/kb_sqlite.py b/astrbot/core/knowledge_base/kb_sqlite.py deleted file mode 100644 index 526b6277..00000000 --- a/astrbot/core/knowledge_base/kb_sqlite.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -知识库独立 SQLite 数据库 - -该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。 -职责: -- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media) -- 提供数据库连接和会话管理 -- 执行数据库迁移和初始化 -""" - -from contextlib import asynccontextmanager -from pathlib import Path - -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine - -from astrbot.core import logger - - -class KBSQLiteDatabase: - """知识库独立 SQLite 数据库 - - 与主数据库 (astrbot.db) 完全隔离的独立数据库,专门用于存储知识库数据。 - - 特点: - - 数据隔离: 知识库数据不会影响主数据库格式 - - 独立备份: 可以单独备份和恢复知识库数据 - - 性能隔离: 大量知识库查询不会影响主业务性能 - """ - - def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: - """初始化知识库数据库 - - Args: - db_path: 数据库文件路径,默认为 data/knowledge_base/kb.db - """ - self.db_path = db_path - self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" - self.inited = False - - # 确保目录存在 - Path(db_path).parent.mkdir(parents=True, exist_ok=True) - - # 创建异步引擎 - self.engine = create_async_engine( - self.DATABASE_URL, - echo=False, - pool_pre_ping=True, - pool_recycle=3600, - ) - - # 创建会话工厂 - self.async_session = async_sessionmaker( - self.engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - @asynccontextmanager - async def get_db(self): - """获取数据库会话 - - 用法: - async with kb_db.get_db() as session: - # 执行数据库操作 - result = await session.execute(stmt) - """ - async with self.async_session() as session: - yield session - - async def initialize(self) -> None: - """初始化数据库,创建表并配置 SQLite 参数""" - # noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表 - from astrbot.core.knowledge_base.models import ( # noqa: F401 - KBChunk, - KBDocument, - KBMedia, - KBSessionConfig, - KnowledgeBase, - ) - from sqlmodel import SQLModel - - async with self.engine.begin() as conn: - # 创建所有知识库相关表 - await conn.run_sync(SQLModel.metadata.create_all) - - # 配置 SQLite 性能优化参数 - await conn.execute(text("PRAGMA journal_mode=WAL")) - await conn.execute(text("PRAGMA synchronous=NORMAL")) - await conn.execute(text("PRAGMA cache_size=20000")) - await conn.execute(text("PRAGMA temp_store=MEMORY")) - await conn.execute(text("PRAGMA mmap_size=134217728")) - await conn.execute(text("PRAGMA optimize")) - await conn.commit() - - self.inited = True - logger.info(f"知识库数据库已初始化: {self.db_path}") - - async def migrate_to_v1(self) -> None: - """执行知识库数据库 v1 迁移 - - 创建所有必要的索引以优化查询性能 - """ - async with self.get_db() as session: - session: AsyncSession - async with session.begin(): - # 创建知识库表索引 - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_kb_kb_id " - "ON knowledge_bases(kb_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_kb_name " - "ON knowledge_bases(kb_name)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_kb_created_at " - "ON knowledge_bases(created_at)" - ) - ) - - # 创建文档表索引 - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_doc_doc_id " - "ON kb_documents(doc_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_doc_kb_id " - "ON kb_documents(kb_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_doc_name " - "ON kb_documents(doc_name)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_doc_type " - "ON kb_documents(file_type)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_doc_created_at " - "ON kb_documents(created_at)" - ) - ) - - # 创建块表索引 - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id " - "ON kb_chunks(chunk_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_chunk_doc_id " - "ON kb_chunks(doc_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id " - "ON kb_chunks(vec_doc_id)" - ) - ) - - # 创建多媒体表索引 - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_media_media_id " - "ON kb_media(media_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_media_doc_id " - "ON kb_media(doc_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_media_type " - "ON kb_media(media_type)" - ) - ) - - # 创建会话配置表索引 - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_session_config_scope_id " - "ON kb_session_config(scope_id)" - ) - ) - await session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_session_config_scope " - "ON kb_session_config(scope)" - ) - ) - - await session.commit() - - logger.info("知识库数据库迁移 v1 完成") - - async def close(self) -> None: - """关闭数据库连接""" - await self.engine.dispose() - logger.info(f"知识库数据库已关闭: {self.db_path}") diff --git a/astrbot/core/knowledge_base/manager.py b/astrbot/core/knowledge_base/manager.py deleted file mode 100644 index f6a6c7f8..00000000 --- a/astrbot/core/knowledge_base/manager.py +++ /dev/null @@ -1,430 +0,0 @@ -"""知识库管理器 - -该模块提供知识库的CRUD操作和文档上传处理流程。 -""" - -import uuid -from pathlib import Path -from typing import Optional - -import aiofiles -from sqlalchemy import func, select, update - -from .kb_sqlite import KBSQLiteDatabase -from astrbot.core.knowledge_base.chunking.base import BaseChunker -from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase -from astrbot.core.knowledge_base.parsers.base import BaseParser -from .vec_db_factory import VecDBFactory - -class KBManager: - """知识库管理器 - - 职责: - - 知识库的 CRUD 操作 - - 文档上传与解析 - - 文档块生成与存储 - - 多媒体资源管理 - """ - - def __init__( - self, - db: KBSQLiteDatabase, - vec_db_factory: VecDBFactory, - storage_path: str, - parsers: dict[str, BaseParser], - chunker: BaseChunker, - provider_manager=None, - ): - self.db = db - self.vec_db_factory = vec_db_factory - self.storage_path = Path(storage_path) - self.media_path = self.storage_path / "media" - self.files_path = self.storage_path / "files" - self.parsers = parsers - self.chunker = chunker - self.provider_manager = provider_manager - - # 确保目录存在 - self.media_path.mkdir(parents=True, exist_ok=True) - self.files_path.mkdir(parents=True, exist_ok=True) - - async def _get_embedding_provider_for_kb(self, kb_id: str): - """根据知识库配置获取 Embedding Provider - - Args: - kb_id: 知识库 ID - - Returns: - EmbeddingProvider: Embedding Provider 实例 - - Raises: - ValueError: 如果找不到合适的 embedding provider - """ - from astrbot.core.knowledge_base.database import KBDatabase - - # 获取知识库配置 - kb_database = KBDatabase(self.db) - kb = await kb_database.get_kb_by_id(kb_id) - if not kb: - raise ValueError(f"知识库不存在: {kb_id}") - - embedding_provider_id = kb.embedding_provider_id - - # 如果没有 provider_manager,使用默认的第一个 - if not self.provider_manager: - raise ValueError("Provider Manager 未初始化") - - embedding_providers = self.provider_manager.embedding_provider_insts - if not embedding_providers: - raise ValueError("系统中没有可用的 Embedding Provider") - - # 如果指定了 provider ID,则查找该 provider - if embedding_provider_id: - for provider in embedding_providers: - if provider.meta().id == embedding_provider_id: - return provider - raise ValueError( - f"未找到配置的 Embedding Provider: {embedding_provider_id}" - ) - - # 使用第一个可用的 provider - return embedding_providers[0] - - # ===== 知识库操作 ===== - - async def create_kb( - self, - kb_name: str, - description: Optional[str] = None, - emoji: Optional[str] = None, - embedding_provider_id: Optional[str] = None, - rerank_provider_id: Optional[str] = None, - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - top_k_dense: Optional[int] = None, - top_k_sparse: Optional[int] = None, - top_m_final: Optional[int] = None, - enable_rerank: Optional[bool] = None, - ) -> KnowledgeBase: - """创建知识库 - - Args: - enable_rerank: 是否启用重排序。 - - 如果明确传入 True/False,则使用该值 - - 如果为 None,则根据是否有可用的 rerank provider 自动决定 - """ - # 智能决定 enable_rerank 的默认值 - if enable_rerank is None: - # 检查是否有可用的 rerank provider - has_rerank_provider = ( - self.provider_manager - and hasattr(self.provider_manager, "rerank_provider_insts") - and len(self.provider_manager.rerank_provider_insts) > 0 - ) - enable_rerank = has_rerank_provider - - kb = KnowledgeBase( - kb_name=kb_name, - description=description, - emoji=emoji or "📚", - embedding_provider_id=embedding_provider_id, - rerank_provider_id=rerank_provider_id, - chunk_size=chunk_size if chunk_size is not None else 512, - chunk_overlap=chunk_overlap if chunk_overlap is not None else 50, - top_k_dense=top_k_dense if top_k_dense is not None else 50, - top_k_sparse=top_k_sparse if top_k_sparse is not None else 50, - top_m_final=top_m_final if top_m_final is not None else 5, - enable_rerank=enable_rerank, - ) - async with self.db.get_db() as session: - session.add(kb) - await session.commit() - await session.refresh(kb) - return kb - - async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]: - """获取知识库""" - async with self.db.get_db() as session: - stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]: - """列出所有知识库""" - async with self.db.get_db() as session: - stmt = ( - select(KnowledgeBase) - .offset(offset) - .limit(limit) - .order_by(KnowledgeBase.created_at.desc()) - ) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def update_kb( - self, - kb_id: str, - kb_name: Optional[str] = None, - description: Optional[str] = None, - emoji: Optional[str] = None, - embedding_provider_id: Optional[str] = None, - rerank_provider_id: Optional[str] = None, - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - top_k_dense: Optional[int] = None, - top_k_sparse: Optional[int] = None, - top_m_final: Optional[int] = None, - enable_rerank: Optional[bool] = None, - ) -> Optional[KnowledgeBase]: - """更新知识库""" - async with self.db.get_db() as session: - stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id) - result = await session.execute(stmt) - kb = result.scalar_one_or_none() - if not kb: - return None - - if kb_name is not None: - kb.kb_name = kb_name - if description is not None: - kb.description = description - if emoji is not None: - kb.emoji = emoji - if embedding_provider_id is not None: - kb.embedding_provider_id = embedding_provider_id - if rerank_provider_id is not None: - kb.rerank_provider_id = rerank_provider_id - if chunk_size is not None: - kb.chunk_size = chunk_size - if chunk_overlap is not None: - kb.chunk_overlap = chunk_overlap - if top_k_dense is not None: - kb.top_k_dense = top_k_dense - if top_k_sparse is not None: - kb.top_k_sparse = top_k_sparse - if top_m_final is not None: - kb.top_m_final = top_m_final - if enable_rerank is not None: - kb.enable_rerank = enable_rerank - - await session.commit() - await session.refresh(kb) - return kb - - async def delete_kb(self, kb_id: str) -> bool: - """删除知识库(级联删除所有文档和资源)""" - # 1. 获取所有文档 - from astrbot.core.knowledge_base.manager_ops import KBManagerOps - - ops = KBManagerOps(self) - docs = await ops.list_documents(kb_id) - - # 2. 删除所有文档(包括文件和向量) - for doc in docs: - await ops.delete_document(doc.doc_id) - - # 3. 删除向量数据库 - await self.vec_db_factory.delete_vec_db(kb_id) - - # 4. 删除知识库记录 - async with self.db.get_db() as session: - stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id) - result = await session.execute(stmt) - kb = result.scalar_one_or_none() - if not kb: - return False - - await session.delete(kb) - await session.commit() - - return True - - # ===== 文档上传 ===== - - async def upload_document( - self, - kb_id: str, - file_name: str, - file_content: bytes, - file_type: str, - ) -> KBDocument: - """上传并处理文档(带原子性保证和失败清理) - - 流程: - 1. 保存原始文件 - 2. 解析文档内容 - 3. 提取多媒体资源 - 4. 分块处理 - 5. 生成向量并存储 - 6. 保存元数据(事务) - 7. 更新统计 - """ - doc_id = str(uuid.uuid4()) - file_path = None - media_paths = [] - vec_doc_ids = [] - - try: - # 1. 保存原始文件 - file_path = self.files_path / kb_id / f"{doc_id}.{file_type}" - file_path.parent.mkdir(parents=True, exist_ok=True) - - async with aiofiles.open(file_path, "wb") as f: - await f.write(file_content) - - # 2. 解析文档 - parser = self.parsers.get(file_type) - if not parser: - raise ValueError(f"不支持的文件类型: {file_type}") - - parse_result = await parser.parse(file_content, file_name) - text_content = parse_result.text - media_items = parse_result.media - - # 3. 保存多媒体资源 - from astrbot.core.knowledge_base.manager_ops import KBManagerOps - - ops = KBManagerOps(self) - saved_media = [] - for media_item in media_items: - media = await ops._save_media( - kb_id=kb_id, - doc_id=doc_id, - media_type=media_item.media_type, - file_name=media_item.file_name, - content=media_item.content, - mime_type=media_item.mime_type, - ) - saved_media.append(media) - media_paths.append(Path(media.file_path)) - - # 4. 文档分块 - chunks_text = await self.chunker.chunk(text_content) - - # 5. 获取 Embedding Provider 和向量数据库 - embedding_provider = await self._get_embedding_provider_for_kb(kb_id) - vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider) - - # 6. 生成向量并存储 - saved_chunks = [] - for idx, chunk_text in enumerate(chunks_text): - # 存储到向量数据库 - vec_doc_id = await vec_db.insert( - content=chunk_text, - metadata={ - "kb_id": kb_id, - "doc_id": doc_id, - "chunk_index": idx, - }, - ) - vec_doc_ids.append(str(vec_doc_id)) - - # 保存块元数据 - chunk = KBChunk( - doc_id=doc_id, - kb_id=kb_id, - chunk_index=idx, - content=chunk_text, - char_count=len(chunk_text), - vec_doc_id=str(vec_doc_id), - ) - saved_chunks.append(chunk) - - # 7. 保存文档元数据(事务) - doc = KBDocument( - doc_id=doc_id, - kb_id=kb_id, - doc_name=file_name, - file_type=file_type, - file_size=len(file_content), - file_path=str(file_path), - chunk_count=len(saved_chunks), - media_count=len(saved_media), - ) - - async with self.db.get_db() as session: - async with session.begin(): - session.add(doc) - for chunk in saved_chunks: - session.add(chunk) - for media in saved_media: - session.add(media) - await session.commit() - - await session.refresh(doc) - - # 8. 更新知识库统计 - await self._update_kb_stats(kb_id) - - return doc - - except Exception as e: - # 失败清理:删除已创建的资源 - from astrbot.core import logger - - logger.error(f"文档上传失败,开始清理资源: {e}") - - # 获取知识库的向量数据库 - try: - embedding_provider = await self._get_embedding_provider_for_kb(kb_id) - vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider) - - # 清理向量数据库 - for vec_id in vec_doc_ids: - try: - await vec_db.delete(vec_id) - except Exception as ve: - logger.warning(f"清理向量失败 {vec_id}: {ve}") - except Exception as vfe: - logger.error(f"获取向量数据库失败: {vfe}") - - # 清理多媒体文件 - for media_path in media_paths: - try: - if media_path.exists(): - media_path.unlink() - except Exception as me: - logger.warning(f"清理多媒体文件失败 {media_path}: {me}") - - # 清理文档文件 - if file_path and file_path.exists(): - try: - file_path.unlink() - except Exception as fe: - logger.warning(f"清理文档文件失败 {file_path}: {fe}") - - # 重新抛出原始异常 - raise - - # ===== 统计更新 ===== - - async def _update_kb_stats(self, kb_id: str): - """更新知识库统计信息(事务中执行)""" - async with self.db.get_db() as session: - async with session.begin(): - # 统计文档数(在事务中查询) - doc_count = ( - await session.scalar( - select(func.count(KBDocument.id)).where( - KBDocument.kb_id == kb_id - ) - ) - or 0 - ) - - # 统计块数(在事务中查询) - chunk_count = ( - await session.scalar( - select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id) - ) - or 0 - ) - - # 更新知识库(在同一事务中) - await session.execute( - update(KnowledgeBase) - .where(KnowledgeBase.kb_id == kb_id) - .values(doc_count=doc_count, chunk_count=chunk_count) - ) - - await session.commit() diff --git a/astrbot/core/knowledge_base/manager_ops.py b/astrbot/core/knowledge_base/manager_ops.py deleted file mode 100644 index 45068926..00000000 --- a/astrbot/core/knowledge_base/manager_ops.py +++ /dev/null @@ -1,323 +0,0 @@ -"""知识库管理器辅助操作 - -该模块提供文档、块和多媒体的管理操作。 -""" - -import uuid -from pathlib import Path -from typing import TYPE_CHECKING - -import aiofiles -from sqlalchemy import delete, func, select - -from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia - -if TYPE_CHECKING: - from astrbot.core.knowledge_base.manager import KBManager - - -class KBManagerOps: - """知识库管理器辅助操作类 - - 职责: - - 文档管理操作 - - 块管理操作 - - 多媒体管理操作 - """ - - def __init__(self, manager: "KBManager"): - self.manager = manager - self.db = manager.db - self.vec_db_factory = manager.vec_db_factory - self.media_path = manager.media_path - self.files_path = manager.files_path - - # ===== 文档操作 ===== - - async def list_documents( - self, kb_id: str, offset: int = 0, limit: int = 100 - ) -> list[KBDocument]: - """列出知识库的所有文档""" - async with self.db.get_db() as session: - stmt = ( - select(KBDocument) - .where(KBDocument.kb_id == kb_id) - .offset(offset) - .limit(limit) - .order_by(KBDocument.created_at.desc()) - ) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def get_document(self, doc_id: str) -> KBDocument | None: - """获取文档详情""" - async with self.db.get_db() as session: - stmt = select(KBDocument).where(KBDocument.doc_id == doc_id) - result = await session.execute(stmt) - return result.scalar_one_or_none() - - async def delete_document(self, doc_id: str) -> bool: - """删除文档(级联删除块、多媒体、向量) - - 采用三阶段删除策略: - 1. 删除向量数据库中的向量(允许部分失败) - 2. 删除SQL数据库中的记录(事务保证原子性) - 3. 删除文件系统中的文件(失败不影响数据一致性) - """ - from astrbot.core import logger - - # 0. 获取文档信息 - doc = await self.get_document(doc_id) - if not doc: - return False - - # 收集所有需要删除的资源 - chunks = await self.list_chunks(doc_id) - media_list = await self.list_media(doc_id) - - # 获取知识库的向量数据库 - embedding_provider = await self.manager._get_embedding_provider_for_kb( - doc.kb_id - ) - vec_db = await self.vec_db_factory.get_vec_db(doc.kb_id, embedding_provider) - - # ===== 第一阶段: 删除向量(可重试) ===== - vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks] - deleted_vec_ids = [] - failed_vec_ids = [] - - for vec_id in vec_ids_to_delete: - try: - await vec_db.delete(vec_id) - deleted_vec_ids.append(vec_id) - except Exception as e: - logger.error(f"删除向量失败: {vec_id}, {e}") - failed_vec_ids.append(vec_id) - - # 如果向量删除失败过多(超过50%),中止操作 - if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5: - logger.error( - f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除" - ) - return False - - # 记录部分失败但继续执行 - if failed_vec_ids: - logger.warning( - f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作" - ) - - # ===== 第二阶段: 删除数据库记录(事务) ===== - async with self.db.get_db() as session: - async with session.begin(): - # 删除块记录 - await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id)) - - # 删除多媒体记录 - await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id)) - - # 删除文档记录 - await session.execute( - delete(KBDocument).where(KBDocument.doc_id == doc_id) - ) - - await session.commit() - - # ===== 第三阶段: 删除文件(失败不影响) ===== - # 删除多媒体文件 - for media in media_list: - try: - media_path = Path(media.file_path) - if media_path.exists(): - media_path.unlink() - except Exception as e: - logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}") - - # 删除文档文件 - try: - file_path = Path(doc.file_path) - if file_path.exists(): - file_path.unlink() - except Exception as e: - logger.warning(f"删除文档文件失败: {doc.file_path}, {e}") - - # ===== 更新统计 ===== - await self.manager._update_kb_stats(doc.kb_id) - - return True - - # ===== 块操作 ===== - - async def list_chunks(self, doc_id: str) -> list[KBChunk]: - """列出文档的所有块""" - async with self.db.get_db() as session: - stmt = ( - select(KBChunk) - .where(KBChunk.doc_id == doc_id) - .order_by(KBChunk.chunk_index) - ) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def delete_chunk(self, chunk_id: str) -> bool: - """删除单个块 - - 流程: - 1. 查询块信息 - 2. 删除向量 - 3. 删除数据库记录 - 4. 更新文档统计 - """ - from astrbot.core import logger - - # 1. 查询块信息 - async with self.db.get_db() as session: - stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id) - result = await session.execute(stmt) - chunk = result.scalar_one_or_none() - if not chunk: - return False - - doc_id = chunk.doc_id - kb_id = chunk.kb_id - vec_doc_id = chunk.vec_doc_id - - # 2. 获取知识库的向量数据库并删除向量 - try: - embedding_provider = await self.manager._get_embedding_provider_for_kb( - kb_id - ) - vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider) - await vec_db.delete(vec_doc_id) - except Exception as e: - logger.error(f"删除向量失败: {vec_doc_id}, {e}") - return False - - # 3. 删除数据库记录 - async with self.db.get_db() as session: - async with session.begin(): - await session.execute( - delete(KBChunk).where(KBChunk.chunk_id == chunk_id) - ) - await session.commit() - - # 4. 更新文档统计 - await self._update_doc_stats(doc_id) - - return True - - # ===== 多媒体操作 ===== - - async def list_media(self, doc_id: str) -> list[KBMedia]: - """列出文档的所有多媒体资源""" - async with self.db.get_db() as session: - stmt = select(KBMedia).where(KBMedia.doc_id == doc_id) - result = await session.execute(stmt) - return list(result.scalars().all()) - - async def delete_media(self, media_id: str) -> bool: - """删除多媒体资源 - - 流程: - 1. 查询媒体信息 - 2. 删除数据库记录 - 3. 删除文件(失败不影响) - 4. 更新文档统计 - """ - from astrbot.core import logger - - # 1. 查询媒体信息 - async with self.db.get_db() as session: - stmt = select(KBMedia).where(KBMedia.media_id == media_id) - result = await session.execute(stmt) - media = result.scalar_one_or_none() - if not media: - return False - - doc_id = media.doc_id - file_path_str = media.file_path - - # 2. 删除数据库记录 - async with self.db.get_db() as session: - async with session.begin(): - await session.execute( - delete(KBMedia).where(KBMedia.media_id == media_id) - ) - await session.commit() - - # 3. 删除文件(失败不影响) - try: - media_path = Path(file_path_str) - if media_path.exists(): - media_path.unlink() - except Exception as e: - logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}") - - # 4. 更新文档统计 - await self._update_doc_stats(doc_id) - - return True - - # ===== 内部辅助方法 ===== - - async def _save_media( - self, - kb_id: str, - doc_id: str, - media_type: str, - file_name: str, - content: bytes, - mime_type: str, - ) -> KBMedia: - """保存多媒体资源""" - media_id = str(uuid.uuid4()) - ext = Path(file_name).suffix - - # 保存文件 - file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}" - file_path.parent.mkdir(parents=True, exist_ok=True) - - async with aiofiles.open(file_path, "wb") as f: - await f.write(content) - - # 创建记录 - media = KBMedia( - media_id=media_id, - doc_id=doc_id, - kb_id=kb_id, - media_type=media_type, - file_name=file_name, - file_path=str(file_path), - file_size=len(content), - mime_type=mime_type, - ) - - return media - - async def _update_doc_stats(self, doc_id: str): - """更新文档统计信息(事务中执行)""" - async with self.db.get_db() as session: - async with session.begin(): - # 统计块数 - chunk_count = ( - await session.scalar( - select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id) - ) - ) or 0 - - # 统计多媒体数 - media_count = ( - await session.scalar( - select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id) - ) - ) or 0 - - # 更新文档 - doc = await session.scalar( - select(KBDocument).where(KBDocument.doc_id == doc_id) - ) - if doc: - doc.chunk_count = chunk_count - doc.media_count = media_count - - await session.commit() diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index 4777ac47..aff629b9 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -1,22 +1,8 @@ -"""知识库管理功能的数据模型定义 - -该模块定义了知识库系统所需的数据模型,包括: -- KnowledgeBase: 知识库表 (存储在独立的 kb.db) -- KBDocument: 文档表 (存储在独立的 kb.db) -- KBChunk: 文档块表 (存储在独立的 kb.db) -- KBMedia: 多媒体资源表 (存储在独立的 kb.db) -- KBSessionConfig: 会话配置表 (存储在独立的 kb.db) - -注意: -- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db) -- 与主数据库 (astrbot.db) 完全解耦 -""" - import uuid from datetime import datetime, timezone from typing import Optional -from sqlmodel import Field, SQLModel, Text +from sqlmodel import Field, SQLModel, Text, UniqueConstraint class KnowledgeBase(SQLModel, table=True): @@ -49,7 +35,6 @@ class KnowledgeBase(SQLModel, table=True): top_k_dense: Optional[int] = Field(default=50, nullable=True) top_k_sparse: Optional[int] = Field(default=50, nullable=True) top_m_final: Optional[int] = Field(default=5, nullable=True) - enable_rerank: Optional[bool] = Field(default=False, nullable=True) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), @@ -58,6 +43,13 @@ class KnowledgeBase(SQLModel, table=True): doc_count: int = Field(default=0, nullable=False) chunk_count: int = Field(default=0, nullable=False) + __table_args__ = ( + UniqueConstraint( + "kb_name", + name="uix_kb_name", + ), + ) + class KBDocument(SQLModel, table=True): """文档表 diff --git a/astrbot/core/knowledge_base/parsers/pdf_parser.py b/astrbot/core/knowledge_base/parsers/pdf_parser.py index 8bb1dea6..fca62687 100644 --- a/astrbot/core/knowledge_base/parsers/pdf_parser.py +++ b/astrbot/core/knowledge_base/parsers/pdf_parser.py @@ -51,10 +51,10 @@ class PDFParser(BaseParser): continue resources = page["/Resources"] - if not resources or "/XObject" not in resources: + if not resources or "/XObject" not in resources: # type: ignore continue - xobjects = resources["/XObject"].get_object() + xobjects = resources["/XObject"].get_object() # type: ignore if not xobjects: continue diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 1b500a1e..4be5c422 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -6,12 +6,15 @@ from dataclasses import dataclass from typing import List -from astrbot.core.knowledge_base.database import KBDatabase +from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever from astrbot.core.provider.provider import RerankProvider -from astrbot.core.db.vec_db.base import BaseVecDB, Result -from ..vec_db_factory import VecDBFactory +from astrbot.core.db.vec_db.base import Result +from astrbot.core.db.vec_db.faiss_impl import FaissVecDB +from ..kb_helper import KBHelper +from astrbot import logger + @dataclass class RetrievalResult: @@ -37,10 +40,9 @@ class RetrievalManager: def __init__( self, - vec_db_factory, # VecDBFactory sparse_retriever: SparseRetriever, rank_fusion: RankFusion, - kb_db: KBDatabase, + kb_db: KBSQLiteDatabase, ): """初始化检索管理器 @@ -50,21 +52,16 @@ class RetrievalManager: rank_fusion: 结果融合器 kb_db: 知识库数据库实例 """ - self.vec_db_factory = vec_db_factory self.sparse_retriever = sparse_retriever self.rank_fusion = rank_fusion self.kb_db = kb_db async def retrieve( self, - vec_db_factory: VecDBFactory, query: str, kb_ids: List[str], - top_k_dense: int = 50, - top_k_sparse: int = 50, - top_n_fusion: int = 20, + kb_id_helper_map: dict[str, KBHelper], top_m_final: int = 5, - rerank_provider: RerankProvider | None = None, ) -> List[RetrievalResult]: """混合检索 @@ -77,35 +74,52 @@ class RetrievalManager: Args: query: 查询文本 kb_ids: 知识库 ID 列表 - top_k_dense: 稠密检索返回数量 - top_k_sparse: 稀疏检索返回数量 - top_n_fusion: 融合后返回数量 top_m_final: 最终返回数量 enable_rerank: 是否启用 Rerank Returns: List[RetrievalResult]: 检索结果列表 """ + if not kb_ids: + return [] + + kb_options: dict = {} + new_kb_ids = [] + for kb_id in kb_ids: + kb_helper = kb_id_helper_map.get(kb_id) + if kb_helper: + kb = kb_helper.kb + kb_options[kb_id] = { + "top_k_dense": kb.top_k_dense or 50, + "top_k_sparse": kb.top_k_sparse or 50, + "top_m_final": kb.top_m_final or 5, + "vec_db": kb_helper.vec_db, + } + new_kb_ids.append(kb_id) + else: + logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索") + + kb_ids = new_kb_ids + # 1. 稠密检索 dense_results = await self._dense_retrieve( query=query, kb_ids=kb_ids, - top_k=top_k_dense, - vec_db=vec_db, + kb_options=kb_options, ) # 2. 稀疏检索 sparse_results = await self.sparse_retriever.retrieve( query=query, kb_ids=kb_ids, - top_k=top_k_sparse, + kb_options=kb_options, ) # 3. 结果融合 fused_results = await self.rank_fusion.fuse( dense_results=dense_results, sparse_results=sparse_results, - top_k=top_n_fusion, + top_k=kb_options.get("top_k_fusion", 20), ) # 4. 转换为 RetrievalResult (获取元数据) @@ -130,24 +144,27 @@ class RetrievalManager: ) # 5. Rerank - if rerank_provider and retrieval_results: + first_rerank = None + for kb_id in kb_ids: + vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] + if vec_db and vec_db.rerank_provider: + first_rerank = vec_db.rerank_provider + break + if first_rerank and retrieval_results: retrieval_results = await self._rerank( query=query, results=retrieval_results, top_k=top_m_final, - rerank_provider=rerank_provider, + rerank_provider=first_rerank, ) - else: - retrieval_results = retrieval_results[:top_m_final] - return retrieval_results + return retrieval_results[:top_m_final] async def _dense_retrieve( self, query: str, kb_ids: List[str], - top_k: int, - vec_db: BaseVecDB, + kb_options: dict, ): """稠密检索 (向量相似度) @@ -162,13 +179,17 @@ class RetrievalManager: List[Result]: 检索结果列表 """ all_results: list[Result] = [] - for kb_id in kb_ids: + if kb_id not in kb_options: + continue try: + vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] + dense_k = int(kb_options[kb_id]["top_k_dense"]) vec_results = await vec_db.retrieve( query=query, - top_k=top_k, - fetch_k=top_k * 2, + k=dense_k, + fetch_k=dense_k * 2, + rerank=False, # 稠密检索阶段不进行 rerank metadata_filters={"kb_id": kb_id}, ) @@ -181,7 +202,7 @@ class RetrievalManager: # 按相似度排序并返回 top_k all_results.sort(key=lambda x: x.similarity, reverse=True) - return all_results[:top_k] + return all_results[: len(all_results) // len(kb_ids)] async def _rerank( self, diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index b05fe1be..a5d5e255 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from typing import Dict, List from astrbot.core.db.vec_db.base import Result -from astrbot.core.knowledge_base.database import KBDatabase +from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult @@ -30,7 +30,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 15c20512..593ebe6d 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -8,7 +8,7 @@ from typing import List from rank_bm25 import BM25Okapi -from astrbot.core.knowledge_base.database import KBDatabase +from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase @dataclass @@ -30,7 +30,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBDatabase): + def __init__(self, kb_db: KBSQLiteDatabase): """初始化稀疏检索器 Args: @@ -43,14 +43,14 @@ class SparseRetriever: self, query: str, kb_ids: List[str], - top_k: int = 50, + kb_options: dict, ) -> List[SparseResult]: """执行稀疏检索 Args: query: 查询文本 kb_ids: 知识库 ID 列表 - top_k: 返回结果数量 + kb_options: 每个知识库的检索选项 Returns: List[SparseResult]: 检索结果列表 @@ -87,4 +87,4 @@ class SparseRetriever: ) results.sort(key=lambda x: x.score, reverse=True) - return results[:top_k] + return results[: len(results) // len(kb_ids)] diff --git a/astrbot/core/knowledge_base/vec_db_factory.py b/astrbot/core/knowledge_base/vec_db_factory.py deleted file mode 100644 index ba2187f4..00000000 --- a/astrbot/core/knowledge_base/vec_db_factory.py +++ /dev/null @@ -1,161 +0,0 @@ -"""向量数据库工厂 - -负责为每个知识库创建和管理独立的向量数据库实例。 - -架构说明: -- 每个知识库拥有独立的向量数据库实例 -- 向量数据库文件以 kb_id 命名 -- 工厂类负责实例的创建、缓存和生命周期管理 -""" - -from pathlib import Path -from typing import Dict, Optional - -from astrbot.core import logger -from astrbot.core.db.vec_db.base import BaseVecDB -from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB -from astrbot.core.provider.provider import EmbeddingProvider - - -class VecDBFactory: - """向量数据库工厂 - - 职责: - - 为每个知识库创建独立的向量数据库实例 - - 缓存已创建的实例以提高性能 - - 管理向量数据库的生命周期 - """ - - def __init__( - self, - storage_base_path: str, - ): - """初始化向量数据库工厂 - - Args: - storage_base_path: 向量数据库存储基础路径 - """ - self.storage_base_path = Path(storage_base_path) - self._instances: Dict[str, BaseVecDB] = {} - - # 确保基础路径存在 - self.storage_base_path.mkdir(parents=True, exist_ok=True) - - async def get_vec_db( - self, kb_id: str, embedding_provider: EmbeddingProvider - ) -> BaseVecDB: - """获取或创建指定知识库的向量数据库实例 - - Args: - kb_id: 知识库 ID - embedding_provider: Embedding Provider 实例 - - Returns: - BaseVecDB: 向量数据库实例 - """ - # 如果已经创建过,直接返回缓存的实例 - if kb_id in self._instances: - return self._instances[kb_id] - - # 创建新实例 - vec_db = await self._create_vec_db(kb_id, embedding_provider) - self._instances[kb_id] = vec_db - - logger.debug(f"创建知识库 {kb_id} 的向量数据库实例") - - return vec_db - - async def _create_vec_db( - self, kb_id: str, embedding_provider: EmbeddingProvider - ) -> BaseVecDB: - """创建向量数据库实例 - - Args: - kb_id: 知识库 ID - embedding_provider: Embedding Provider 实例 - - Returns: - BaseVecDB: 向量数据库实例 - """ - # 为每个知识库创建独立的存储路径 - kb_storage_path = self.storage_base_path / kb_id - kb_storage_path.mkdir(parents=True, exist_ok=True) - - doc_store_path = str(kb_storage_path / "documents.db") - index_store_path = str(kb_storage_path / "index.faiss") - - vec_db = FaissVecDB( - doc_store_path=doc_store_path, - index_store_path=index_store_path, - embedding_provider=embedding_provider, - ) - - await vec_db.initialize() - - return vec_db - - async def delete_vec_db(self, kb_id: str) -> bool: - """删除指定知识库的向量数据库 - - Args: - kb_id: 知识库 ID - - Returns: - bool: 是否删除成功 - """ - # 关闭并移除缓存的实例 - if kb_id in self._instances: - try: - await self._instances[kb_id].close() - except Exception as e: - logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}") - - del self._instances[kb_id] - - # 删除文件系统中的向量数据库文件 - kb_storage_path = self.storage_base_path / kb_id - if kb_storage_path.exists(): - try: - import shutil - - shutil.rmtree(kb_storage_path) - logger.info(f"已删除知识库 {kb_id} 的向量数据库文件") - return True - except Exception as e: - logger.error(f"删除向量数据库文件失败 ({kb_id}): {e}") - return False - - return True - - async def close_all(self): - """关闭所有向量数据库实例""" - for kb_id, vec_db in list(self._instances.items()): - try: - await vec_db.close() - logger.debug(f"已关闭知识库 {kb_id} 的向量数据库") - except Exception as e: - logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}") - - self._instances.clear() - - def has_instance(self, kb_id: str) -> bool: - """检查是否已创建指定知识库的向量数据库实例 - - Args: - kb_id: 知识库 ID - - Returns: - bool: 是否已创建实例 - """ - return kb_id in self._instances - - def get_cached_instance(self, kb_id: str) -> Optional[BaseVecDB]: - """获取已缓存的向量数据库实例(不创建新实例) - - Args: - kb_id: 知识库 ID - - Returns: - Optional[BaseVecDB]: 向量数据库实例,如果不存在则返回 None - """ - return self._instances.get(kb_id) diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index d870c398..bf40b3ca 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -15,13 +15,15 @@ async def inject_kb_context( p_ctx: Pipeline context req: Provider request """ - kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector() - if not kb_injector: + kb_mgr = p_ctx.plugin_manager.context.kb_manager + kb_names = p_ctx.astrbot_config.get("kb_names", []) + + if not kb_names: return - kb_context = await kb_injector.retrieve_and_inject( - unified_msg_origin=umo, + + kb_context = await kb_mgr.retrieve( query=req.prompt, - top_k=top_k, + kb_names=kb_names, ) if not kb_context: return diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index fb445894..0229f4db 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -19,7 +19,7 @@ from astrbot.core.platform import Platform from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager -from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager +from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from .star import star_registry, StarMetadata, star_map from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index e4736a24..7a70b43d 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -1,5 +1,7 @@ """知识库管理 API 路由""" +import uuid +import aiofiles import os import traceback from quart import request @@ -29,8 +31,7 @@ class KnowledgeBaseRoute(Route): # 注册路由 self.routes = { # 系统管理 - "/kb/status": ("GET", self.get_kb_status), - "/kb/initialize": ("POST", self.initialize_kb), + # "/kb/status": ("GET", self.get_kb_status), # 知识库管理 "/kb/list": ("GET", self.list_kbs), "/kb/create": ("POST", self.create_kb), @@ -42,188 +43,21 @@ class KnowledgeBaseRoute(Route): "/kb/document/list": ("GET", self.list_documents), "/kb/document/upload": ("POST", self.upload_document), "/kb/document/get": ("GET", self.get_document), - "/kb/document/delete": ("POST", self.delete_document), - # 块管理 + # "/kb/document/delete": ("POST", self.delete_document), + # # 块管理 "/kb/chunk/list": ("GET", self.list_chunks), - "/kb/chunk/get": ("GET", self.get_chunk), - "/kb/chunk/delete": ("POST", self.delete_chunk), - # 多媒体管理 - "/kb/media/list": ("GET", self.list_media), - "/kb/media/delete": ("POST", self.delete_media), + # "/kb/chunk/get": ("GET", self.get_chunk), + # "/kb/chunk/delete": ("POST", self.delete_chunk), + # # 多媒体管理 + # "/kb/media/list": ("GET", self.list_media), + # "/kb/media/delete": ("POST", self.delete_media), # 检索 "/kb/retrieve": ("POST", self.retrieve), - # 会话配置 - "/kb/session/config/get": ("GET", self.get_session_config), - "/kb/session/config/set": ("POST", self.set_session_config), - "/kb/session/config/delete": ("POST", self.delete_session_config), - "/kb/session/config/list": ("GET", self.list_session_configs), - "/kb/session/config/list_by_kb": ("GET", self.list_sessions_by_kb), } self.register_routes() def _get_kb_manager(self): - """获取知识库管理器实例""" - if not self.kb_manager: - if not hasattr(self.core_lifecycle, "kb_manager"): - raise ValueError("知识库模块未启用或未初始化") - # 从 KnowledgeBaseManager (lifecycle 管理器) 获取实际的组件 - kb_lifecycle = self.core_lifecycle.kb_manager - if not kb_lifecycle.is_initialized: - raise ValueError("知识库模块未完成初始化") - - self.kb_manager = kb_lifecycle.kb_manager - self.kb_db = kb_lifecycle.kb_database - self.retrieval_manager = kb_lifecycle.retrieval_manager - return self.kb_manager - - # ===== 系统管理 API ===== - - async def get_kb_status(self): - """获取知识库模块状态 - - 返回知识库模块是否已启用和初始化 - """ - try: - if not hasattr(self.core_lifecycle, "kb_manager"): - return ( - Response() - .ok( - { - "enabled": False, - "initialized": False, - "message": "知识库模块未启用", - } - ) - .__dict__ - ) - - kb_lifecycle = self.core_lifecycle.kb_manager - config = kb_lifecycle.config - - # 检查是否启用 - enabled = config.get("enabled", False) - if not enabled: - return ( - Response() - .ok( - { - "enabled": False, - "initialized": False, - "message": "知识库功能未在配置中启用", - } - ) - .__dict__ - ) - - # 检查是否初始化 - initialized = kb_lifecycle.is_initialized - if not initialized: - # 检查是否有embedding provider - has_embedding = ( - len(kb_lifecycle.provider_manager.embedding_provider_insts) > 0 - ) - if not has_embedding: - return ( - Response() - .ok( - { - "enabled": True, - "initialized": False, - "message": "未配置 Embedding Provider,请先在提供商管理中添加支持 embedding 的模型", - } - ) - .__dict__ - ) - else: - return ( - Response() - .ok( - { - "enabled": True, - "initialized": False, - "message": "知识库模块未初始化,请点击初始化按钮", - } - ) - .__dict__ - ) - - return ( - Response() - .ok( - { - "enabled": True, - "initialized": True, - "message": "知识库模块运行正常", - } - ) - .__dict__ - ) - - except Exception as e: - logger.error(f"获取知识库状态失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库状态失败: {str(e)}").__dict__ - - async def initialize_kb(self): - """初始化或重新初始化知识库模块 - - 用于在运行时动态初始化知识库模块 - """ - try: - if not hasattr(self.core_lifecycle, "kb_manager"): - return Response().error("知识库模块未启用").__dict__ - - kb_lifecycle = self.core_lifecycle.kb_manager - config = kb_lifecycle.config - - # 检查是否启用 - enabled = config.get("enabled", False) - if not enabled: - return ( - Response() - .error( - "知识库功能未在配置中启用,请在配置文件中设置 knowledge_base.enabled = true" - ) - .__dict__ - ) - - # 尝试初始化 - logger.info("收到知识库初始化请求,正在初始化...") - success = await kb_lifecycle.reinitialize() - - if success: - # 清除缓存的实例,强制下次重新获取 - self.kb_manager = None - self.kb_db = None - self.retrieval_manager = None - - return Response().ok(message="知识库模块初始化成功").__dict__ - else: - # 检查失败原因 - has_embedding = ( - len(kb_lifecycle.provider_manager.embedding_provider_insts) > 0 - ) - if not has_embedding: - return ( - Response() - .error( - "初始化失败:未配置 Embedding Provider,请先在提供商管理中添加支持 embedding 的模型" - ) - .__dict__ - ) - else: - return ( - Response() - .error("知识库模块初始化失败,请查看后端日志获取详细信息") - .__dict__ - ) - - except Exception as e: - logger.error(f"初始化知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"初始化知识库失败: {str(e)}").__dict__ - - # ===== 知识库管理 API ===== + return self.core_lifecycle.kb_manager async def list_kbs(self): """获取知识库列表 @@ -237,23 +71,8 @@ class KnowledgeBaseRoute(Route): kb_manager = self._get_kb_manager() page = request.args.get("page", 1, type=int) page_size = request.args.get("page_size", 20, type=int) - refresh_stats = request.args.get("refresh_stats", "false").lower() == "true" - # 转换为 offset 和 limit - offset = (page - 1) * page_size - limit = page_size - - kbs = await kb_manager.list_kbs(offset=offset, limit=limit) - - # 如果需要刷新统计信息 - if refresh_stats: - for kb in kbs: - try: - await kb_manager._update_kb_stats(kb.kb_id) - except Exception as e: - logger.warning(f"刷新知识库 {kb.kb_id} 统计信息失败: {e}") - # 刷新后重新查询以获取最新数据 - kbs = await kb_manager.list_kbs(offset=offset, limit=limit) + kbs = await kb_manager.list_kbs() # 转换为字典列表 kb_list = [] @@ -267,15 +86,11 @@ class KnowledgeBaseRoute(Route): "rerank_provider_id": kb.rerank_provider_id, "doc_count": kb.doc_count, "chunk_count": kb.chunk_count, - # 添加配置参数 "chunk_size": kb.chunk_size or 512, "chunk_overlap": kb.chunk_overlap or 50, "top_k_dense": kb.top_k_dense or 50, "top_k_sparse": kb.top_k_sparse or 50, "top_m_final": kb.top_m_final or 5, - "enable_rerank": kb.enable_rerank - if kb.enable_rerank is not None - else True, "created_at": kb.created_at.isoformat(), "updated_at": kb.updated_at.isoformat(), } @@ -286,7 +101,6 @@ class KnowledgeBaseRoute(Route): .ok({"items": kb_list, "page": page, "page_size": page_size}) .__dict__ ) - except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: @@ -308,40 +122,25 @@ class KnowledgeBaseRoute(Route): - top_k_dense: 密集检索数量 (可选, 默认50) - top_k_sparse: 稀疏检索数量 (可选, 默认50) - top_m_final: 最终返回数量 (可选, 默认5) - - enable_rerank: 是否启用Rerank (可选, 默认True) """ try: kb_manager = self._get_kb_manager() data = await request.json - kb_name = data.get("kb_name") if not kb_name: return Response().error("知识库名称不能为空").__dict__ description = data.get("description") emoji = data.get("emoji") - - # 提取 provider ID (前端可能传入完整对象或直接传入ID字符串) - embedding_provider = data.get("embedding_provider_id") - if isinstance(embedding_provider, dict): - embedding_provider_id = embedding_provider.get("id") - else: - embedding_provider_id = embedding_provider - - rerank_provider = data.get("rerank_provider_id") - if isinstance(rerank_provider, dict): - rerank_provider_id = rerank_provider.get("id") - else: - rerank_provider_id = rerank_provider - + embedding_provider_id = data.get("embedding_provider_id") + rerank_provider_id = data.get("rerank_provider_id") chunk_size = data.get("chunk_size") chunk_overlap = data.get("chunk_overlap") top_k_dense = data.get("top_k_dense") top_k_sparse = data.get("top_k_sparse") top_m_final = data.get("top_m_final") - enable_rerank = data.get("enable_rerank") - kb = await kb_manager.create_kb( + kb_helper = await kb_manager.create_kb( kb_name=kb_name, description=description, emoji=emoji, @@ -352,8 +151,8 @@ class KnowledgeBaseRoute(Route): top_k_dense=top_k_dense, top_k_sparse=top_k_sparse, top_m_final=top_m_final, - enable_rerank=enable_rerank, ) + kb = kb_helper.kb kb_dict = { "kb_id": kb.kb_id, @@ -369,9 +168,6 @@ class KnowledgeBaseRoute(Route): "top_k_dense": kb.top_k_dense or 50, "top_k_sparse": kb.top_k_sparse or 50, "top_m_final": kb.top_m_final or 5, - "enable_rerank": kb.enable_rerank - if kb.enable_rerank is not None - else True, "created_at": kb.created_at.isoformat(), "updated_at": kb.updated_at.isoformat(), } @@ -397,9 +193,10 @@ class KnowledgeBaseRoute(Route): if not kb_id: return Response().error("缺少参数 kb_id").__dict__ - kb = await kb_manager.get_kb(kb_id) - if not kb: + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: return Response().error("知识库不存在").__dict__ + kb = kb_helper.kb kb_dict = { "kb_id": kb.kb_id, @@ -451,26 +248,13 @@ class KnowledgeBaseRoute(Route): kb_name = data.get("kb_name") description = data.get("description") emoji = data.get("emoji") - - # 提取 provider ID (前端可能传入完整对象或直接传入ID字符串) - embedding_provider = data.get("embedding_provider_id") - if isinstance(embedding_provider, dict): - embedding_provider_id = embedding_provider.get("id") - else: - embedding_provider_id = embedding_provider - - rerank_provider = data.get("rerank_provider_id") - if isinstance(rerank_provider, dict): - rerank_provider_id = rerank_provider.get("id") - else: - rerank_provider_id = rerank_provider - + embedding_provider_id = data.get("embedding_provider_id") + rerank_provider_id = data.get("rerank_provider_id") chunk_size = data.get("chunk_size") chunk_overlap = data.get("chunk_overlap") top_k_dense = data.get("top_k_dense") top_k_sparse = data.get("top_k_sparse") top_m_final = data.get("top_m_final") - enable_rerank = data.get("enable_rerank") # 检查是否至少提供了一个更新字段 if all( @@ -486,12 +270,11 @@ class KnowledgeBaseRoute(Route): top_k_dense, top_k_sparse, top_m_final, - enable_rerank, ] ): return Response().error("至少需要提供一个更新字段").__dict__ - kb = await kb_manager.update_kb( + kb_helper = await kb_manager.update_kb( kb_id=kb_id, kb_name=kb_name, description=description, @@ -503,12 +286,13 @@ class KnowledgeBaseRoute(Route): top_k_dense=top_k_dense, top_k_sparse=top_k_sparse, top_m_final=top_m_final, - enable_rerank=enable_rerank, ) - if not kb: + if not kb_helper: return Response().error("知识库不存在").__dict__ + kb = kb_helper.kb + kb_dict = { "kb_id": kb.kb_id, "kb_name": kb.kb_name, @@ -522,10 +306,6 @@ class KnowledgeBaseRoute(Route): "chunk_overlap": kb.chunk_overlap or 50, "top_k_dense": kb.top_k_dense or 50, "top_k_sparse": kb.top_k_sparse or 50, - "top_m_final": kb.top_m_final or 5, - "enable_rerank": kb.enable_rerank - if kb.enable_rerank is not None - else True, "created_at": kb.created_at.isoformat(), "updated_at": kb.updated_at.isoformat(), } @@ -578,9 +358,10 @@ class KnowledgeBaseRoute(Route): if not kb_id: return Response().error("缺少参数 kb_id").__dict__ - kb = await kb_manager.get_kb(kb_id) - if not kb: + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: return Response().error("知识库不存在").__dict__ + kb = kb_helper.kb stats = { "kb_id": kb.kb_id, @@ -615,33 +396,19 @@ class KnowledgeBaseRoute(Route): kb_id = request.args.get("kb_id") if not kb_id: return Response().error("缺少参数 kb_id").__dict__ + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) + page_size = request.args.get("page_size", 100, type=int) offset = (page - 1) * page_size limit = page_size - # 使用 KBManagerOps 获取文档列表 - from astrbot.core.knowledge_base.manager_ops import KBManagerOps + doc_list = await kb_helper.list_documents(offset=offset, limit=limit) - ops = KBManagerOps(kb_manager) - docs = await ops.list_documents(kb_id, offset=offset, limit=limit) - - doc_list = [] - for doc in docs: - doc_dict = { - "doc_id": doc.doc_id, - "kb_id": doc.kb_id, - "doc_name": doc.doc_name, - "file_type": doc.file_type, - "file_size": doc.file_size, - "chunk_count": doc.chunk_count, - "media_count": doc.media_count, - "created_at": doc.created_at.isoformat(), - "updated_at": doc.updated_at.isoformat(), - } - doc_list.append(doc_dict) + doc_list = [doc.model_dump() for doc in doc_list] return ( Response() @@ -677,6 +444,7 @@ class KnowledgeBaseRoute(Route): # 检查 Content-Type content_type = request.content_type + kb_id = None if content_type and "multipart/form-data" in content_type: # 方式 1: multipart/form-data @@ -693,10 +461,6 @@ class KnowledgeBaseRoute(Route): file = files["file"] file_name = file.filename - # 使用 aiofiles 异步读取文件内容 - import uuid - import aiofiles - # 保存到临时文件 temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}" await file.save(temp_file_path) @@ -739,9 +503,12 @@ class KnowledgeBaseRoute(Route): # 提取文件类型 file_type = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "" + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ + # 上传文档 - doc = await kb_manager.upload_document( - kb_id=kb_id, + doc = await kb_helper.upload_document( file_name=file_name, file_content=file_content, file_type=file_type, @@ -776,14 +543,17 @@ class KnowledgeBaseRoute(Route): """ try: kb_manager = self._get_kb_manager() + kb_id = request.args.get("kb_id") + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ doc_id = request.args.get("doc_id") if not doc_id: return Response().error("缺少参数 doc_id").__dict__ + kb_helper = await kb_manager.get_kb(kb_id) + if not kb_helper: + return Response().error("知识库不存在").__dict__ - from astrbot.core.knowledge_base.manager_ops import KBManagerOps - - ops = KBManagerOps(kb_manager) - doc = await ops.get_document(doc_id) + doc = await kb_helper.get_document(doc_id) if not doc: return Response().error("文档不存在").__dict__ @@ -809,70 +579,43 @@ class KnowledgeBaseRoute(Route): logger.error(traceback.format_exc()) return Response().error(f"获取文档详情失败: {str(e)}").__dict__ - async def delete_document(self): - """删除文档 - - Body: - - doc_id: 文档 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - doc_id = data.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - - from astrbot.core.knowledge_base.manager_ops import KBManagerOps - - ops = KBManagerOps(kb_manager) - success = await ops.delete_document(doc_id) - if not success: - return Response().error("文档不存在").__dict__ - - return Response().ok(message="删除文档成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除文档失败: {str(e)}").__dict__ - - # ===== 块管理 API ===== - async def list_chunks(self): """获取块列表 Query 参数: - - doc_id: 文档 ID (必填) + - kb_id: 知识库 ID (必填) + - page: 页码 (默认 1) + - page_size: 每页数量 (默认 20) """ try: kb_manager = self._get_kb_manager() + kb_id = request.args.get("kb_id") doc_id = request.args.get("doc_id") + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 100, type=int) + if not kb_id: + return Response().error("缺少参数 kb_id").__dict__ if not doc_id: return Response().error("缺少参数 doc_id").__dict__ - - from astrbot.core.knowledge_base.manager_ops import KBManagerOps - - ops = KBManagerOps(kb_manager) - chunks = await ops.list_chunks(doc_id) - - chunk_list = [] - for chunk in chunks: - chunk_dict = { - "chunk_id": chunk.chunk_id, - "doc_id": chunk.doc_id, - "kb_id": chunk.kb_id, - "chunk_index": chunk.chunk_index, - "content": chunk.content, - "char_count": chunk.char_count, - "created_at": chunk.created_at.isoformat(), - } - chunk_list.append(chunk_dict) - - return Response().ok({"items": chunk_list}).__dict__ - + kb_helper = await kb_manager.get_kb(kb_id) + offset = (page - 1) * page_size + limit = page_size + if not kb_helper: + return Response().error("知识库不存在").__dict__ + chunk_list = await kb_helper.kb_db.get_chunks_by_doc_id( + doc_id=doc_id, offset=offset, limit=limit + ) + return ( + Response() + .ok( + data={ + "items": [chunk.model_dump() for chunk in chunk_list], + "page": page, + "page_size": page_size, + } + ) + .__dict__ + ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: @@ -880,146 +623,6 @@ class KnowledgeBaseRoute(Route): logger.error(traceback.format_exc()) return Response().error(f"获取块列表失败: {str(e)}").__dict__ - async def get_chunk(self): - """获取块详情 - - Query 参数: - - chunk_id: 块 ID (必填) - """ - try: - kb_db = self.kb_db if self.kb_db else self._get_kb_manager() and self.kb_db - chunk_id = request.args.get("chunk_id") - if not chunk_id: - return Response().error("缺少参数 chunk_id").__dict__ - - chunk_data = await kb_db.get_chunk_with_metadata(chunk_id) - if not chunk_data: - return Response().error("块不存在").__dict__ - - chunk = chunk_data["chunk"] - doc = chunk_data["document"] - kb = chunk_data["knowledge_base"] - - chunk_dict = { - "chunk_id": chunk.chunk_id, - "doc_id": chunk.doc_id, - "kb_id": chunk.kb_id, - "chunk_index": chunk.chunk_index, - "content": chunk.content, - "char_count": chunk.char_count, - "created_at": chunk.created_at.isoformat(), - "document": { - "doc_name": doc.doc_name, - "file_type": doc.file_type, - }, - "knowledge_base": { - "kb_name": kb.kb_name, - }, - } - - return Response().ok(chunk_dict).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取块详情失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取块详情失败: {str(e)}").__dict__ - - async def delete_chunk(self): - """删除块 - - Body: - - chunk_id: 块 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - chunk_id = data.get("chunk_id") - if not chunk_id: - return Response().error("缺少参数 chunk_id").__dict__ - - success = await kb_manager.delete_chunk(chunk_id) - if not success: - return Response().error("块不存在").__dict__ - - return Response().ok(message="删除块成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除块失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除块失败: {str(e)}").__dict__ - - # ===== 多媒体管理 API ===== - - async def list_media(self): - """获取多媒体资源列表 - - Query 参数: - - doc_id: 文档 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - doc_id = request.args.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - - media_list = await kb_manager.list_media(doc_id) - - media_result = [] - for media in media_list: - media_dict = { - "media_id": media.media_id, - "doc_id": media.doc_id, - "kb_id": media.kb_id, - "media_type": media.media_type, - "file_name": media.file_name, - "file_path": media.file_path, - "file_size": media.file_size, - "mime_type": media.mime_type, - "created_at": media.created_at.isoformat(), - } - media_result.append(media_dict) - - return Response().ok({"media": media_result}).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取多媒体列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取多媒体列表失败: {str(e)}").__dict__ - - async def delete_media(self): - """删除多媒体资源 - - Body: - - media_id: 多媒体 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - media_id = data.get("media_id") - if not media_id: - return Response().error("缺少参数 media_id").__dict__ - - success = await kb_manager.delete_media(media_id) - if not success: - return Response().error("多媒体资源不存在").__dict__ - - return Response().ok(message="删除多媒体资源成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除多媒体资源失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除多媒体资源失败: {str(e)}").__dict__ - # ===== 检索 API ===== async def retrieve(self): @@ -1033,54 +636,26 @@ class KnowledgeBaseRoute(Route): """ try: kb_manager = self._get_kb_manager() - retrieval_manager = ( - self.retrieval_manager - if self.retrieval_manager - else self._get_kb_manager() and self.retrieval_manager - ) data = await request.json query = data.get("query") - kb_ids = data.get("kb_ids") + kb_names = data.get("kb_names") if not query: return Response().error("缺少参数 query").__dict__ - if not kb_ids or not isinstance(kb_ids, list): - return Response().error("缺少参数 kb_ids 或格式错误").__dict__ + if not kb_names or not isinstance(kb_names, list): + return Response().error("缺少参数 kb_names 或格式错误").__dict__ top_k = data.get("top_k", 5) - enable_rerank = data.get("enable_rerank") - results = await retrieval_manager.retrieve( + results = await kb_manager.retrieve( query=query, - kb_ids=kb_ids, + kb_names=kb_names, top_m_final=top_k, - enable_rerank=enable_rerank, ) - - # 获取manager_ops以查询文档和知识库信息 - from astrbot.core.knowledge_base.manager_ops import KBManagerOps - - ops = KBManagerOps(kb_manager) - result_list = [] - for result in results: - # 查询文档和知识库名称 - doc = await ops.get_document(result.doc_id) - kb = await kb_manager.get_kb(result.kb_id) - - result_dict = { - "chunk_id": result.chunk_id, - "doc_id": result.doc_id, - "kb_id": result.kb_id, - "doc_name": doc.doc_name if doc else "未知文档", - "kb_name": kb.kb_name if kb else "未知知识库", - "chunk_index": result.metadata.get("chunk_index", 0), - "content": result.content, - "char_count": len(result.content), - "score": result.score, - } - result_list.append(result_dict) + if results: + result_list = results["results"] return ( Response() @@ -1094,205 +669,3 @@ class KnowledgeBaseRoute(Route): logger.error(f"检索失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"检索失败: {str(e)}").__dict__ - - # ===== 会话配置 API ===== - - async def get_session_config(self): - """获取会话知识库配置 - - Query 参数: - - session_id: 会话 ID (必填) - """ - try: - kb_db = self.kb_db if self.kb_db else self._get_kb_manager() and self.kb_db - session_id = request.args.get("session_id") - if not session_id: - return Response().error("缺少参数 session_id").__dict__ - - kb_ids = await kb_db.get_session_kb_ids(session_id) - - return Response().ok({"session_id": session_id, "kb_ids": kb_ids}).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取会话配置失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取会话配置失败: {str(e)}").__dict__ - - async def set_session_config(self): - """设置会话知识库配置 - - Body: - - scope: 配置范围 (session/platform) (必填) - - scope_id: 范围标识 (会话 ID 或平台 ID) (必填) - - kb_ids: 知识库 ID 列表 (必填) - - top_k: 返回结果数量 (可选) - - enable_rerank: 是否启用Rerank (可选) - """ - try: - kb_db = self.kb_db if self.kb_db else self._get_kb_manager() and self.kb_db - data = await request.json - - scope = data.get("scope") - scope_id = data.get("scope_id") - kb_ids = data.get("kb_ids") - top_k = data.get("top_k") - enable_rerank = data.get("enable_rerank") - - if not scope or not scope_id: - return Response().error("缺少参数 scope 或 scope_id").__dict__ - if kb_ids is None or not isinstance(kb_ids, list): - return Response().error("缺少参数 kb_ids 或格式错误").__dict__ - - if scope not in ["session", "platform"]: - return Response().error("scope 必须是 session 或 platform").__dict__ - - await kb_db.set_session_kb_ids( - scope=scope, - scope_id=scope_id, - kb_ids=kb_ids, - top_k=top_k, - enable_rerank=enable_rerank, - ) - - return Response().ok(message="设置会话配置成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"设置会话配置失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"设置会话配置失败: {str(e)}").__dict__ - - async def delete_session_config(self): - """删除会话知识库配置 - - Body: - - scope: 配置范围 (session/platform) (必填) - - scope_id: 范围标识 (会话 ID 或平台 ID) (必填) - """ - try: - kb_db = self.kb_db if self.kb_db else self._get_kb_manager() and self.kb_db - data = await request.json - - scope = data.get("scope") - scope_id = data.get("scope_id") - - if not scope or not scope_id: - return Response().error("缺少参数 scope 或 scope_id").__dict__ - - success = await kb_db.delete_session_kb_config( - scope=scope, - scope_id=scope_id, - ) - - if not success: - return Response().error("配置不存在").__dict__ - - return Response().ok(message="删除会话配置成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除会话配置失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除会话配置失败: {str(e)}").__dict__ - - async def list_session_configs(self): - """获取所有会话配置列表 - - Query 参数: - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - """ - try: - kb_db = self.kb_db if self.kb_db else self._get_kb_manager() and self.kb_db - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - - offset = (page - 1) * page_size - limit = page_size - - configs = await kb_db.list_all_session_configs(offset=offset, limit=limit) - - import json - - config_list = [] - for config in configs: - config_dict = { - "config_id": config.config_id, - "scope": config.scope, - "scope_id": config.scope_id, - "kb_ids": json.loads(config.kb_ids), - "created_at": config.created_at.isoformat(), - "updated_at": config.updated_at.isoformat(), - } - config_list.append(config_dict) - - return ( - Response() - .ok({"items": config_list, "page": page, "page_size": page_size}) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取会话配置列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取会话配置列表失败: {str(e)}").__dict__ - - async def list_sessions_by_kb(self): - """获取使用特定知识库的会话列表 - - Query 参数: - - kb_id: 知识库 ID (必填) - """ - try: - kb_db = self.kb_db if self.kb_db else self._get_kb_manager() and self.kb_db - kb_id = request.args.get("kb_id") - - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - # 获取所有会话配置 - configs = await kb_db.list_all_session_configs(offset=0, limit=1000) - - import json - - # 筛选包含该知识库的会话 - session_list = [] - for config in configs: - kb_ids = json.loads(config.kb_ids) - if kb_id in kb_ids: - session_dict = { - "config_id": config.config_id, - "scope": config.scope, - "scope_id": config.scope_id, - "kb_ids": kb_ids, - "top_k": config.top_k, - "enable_rerank": config.enable_rerank, - "created_at": config.created_at.isoformat(), - "updated_at": config.updated_at.isoformat(), - } - session_list.append(session_dict) - - return ( - Response() - .ok( - { - "sessions": session_list, - "total": len(session_list), - "kb_id": kb_id, - } - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库会话列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库会话列表失败: {str(e)}").__dict__ diff --git a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json index b1222a51..35c430aa 100644 --- a/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/en-US/features/knowledge-base/document.json @@ -21,7 +21,8 @@ "delete": "Delete", "preview": "Preview", "search": "Search Chunks", - "searchPlaceholder": "Enter keywords to search chunks..." + "searchPlaceholder": "Enter keywords to search chunks...", + "showing": "Showing" }, "edit": { "title": "Edit Chunk", diff --git a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json index c493cef2..22781666 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json +++ b/dashboard/src/i18n/locales/zh-CN/features/knowledge-base/document.json @@ -21,7 +21,8 @@ "delete": "删除", "preview": "预览", "search": "搜索分块", - "searchPlaceholder": "输入关键词搜索分块内容..." + "searchPlaceholder": "输入关键词搜索分块内容...", + "showing": "显示" }, "edit": { "title": "编辑分块", diff --git a/dashboard/src/views/knowledge-base/DocumentDetail.vue b/dashboard/src/views/knowledge-base/DocumentDetail.vue index 6386632d..991afc37 100644 --- a/dashboard/src/views/knowledge-base/DocumentDetail.vue +++ b/dashboard/src/views/knowledge-base/DocumentDetail.vue @@ -82,7 +82,7 @@ {{ t('chunks.title') }} - {{ chunks.length }} {{ t('chunks.title') }} + {{ totalChunks }} {{ t('chunks.title') }} + + + +
+
+ {{ t('chunks.showing') }} {{ (page - 1) * pageSize + 1 }} - {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }} +
+
+ + +
+
@@ -212,71 +224,6 @@ - - - - - {{ t('edit.title') }} - - - - - - - - - - - - {{ t('edit.cancel') }} - - - {{ t('edit.save') }} - - - - - - - - - {{ t('delete.title') }} - - -

{{ t('delete.confirmText') }}

- - {{ t('delete.warning') }} - -
- - - - - {{ t('delete.cancel') }} - - - {{ t('delete.confirm') }} - - -
-
- {{ snackbar.text }} @@ -299,16 +246,16 @@ const docId = ref(route.params.docId as string) // 状态 const loading = ref(true) const loadingChunks = ref(false) -const saving = ref(false) -const deleting = ref(false) const document = ref({}) const chunks = ref([]) const searchQuery = ref('') const showViewDialog = ref(false) -const showEditDialog = ref(false) -const showDeleteDialog = ref(false) const selectedChunk = ref(null) -const deleteTarget = ref(null) + +// 分页状态 +const page = ref(1) +const pageSize = ref(10) +const totalChunks = ref(0) const snackbar = ref({ show: false, @@ -322,17 +269,12 @@ const showSnackbar = (text: string, color: string = 'success') => { snackbar.value.show = true } -// 编辑表单 -const editForm = ref({ - content: '' -}) - // 表格列 const headers = [ { title: t('chunks.index'), key: 'chunk_index', width: 100 }, { title: t('chunks.content'), key: 'content', sortable: false }, { title: t('chunks.charCount'), key: 'char_count', width: 150 }, - { title: t('chunks.actions'), key: 'actions', sortable: false, align: 'end', width: 150 } + { title: t('chunks.actions'), key: 'actions', sortable: false, width: 150 } ] // 过滤分块 @@ -349,7 +291,7 @@ const loadDocument = async () => { loading.value = true try { const response = await axios.get('/api/kb/document/get', { - params: { doc_id: docId.value } + params: { doc_id: docId.value, kb_id: kbId.value } }) if (response.data.status === 'ok') { document.value = response.data.data @@ -367,10 +309,16 @@ const loadChunks = async () => { loadingChunks.value = true try { const response = await axios.get('/api/kb/chunk/list', { - params: { doc_id: docId.value } + params: { + doc_id: docId.value, + kb_id: kbId.value, + page: page.value, + page_size: pageSize.value + } }) if (response.data.status === 'ok') { chunks.value = response.data.data.items || [] + totalChunks.value = response.data.data.items.length } } catch (error) { console.error('Failed to load chunks:', error) @@ -380,84 +328,24 @@ const loadChunks = async () => { } } +// 处理分页变化 +const handlePageChange = (newPage: number) => { + page.value = newPage + loadChunks() +} + +const handlePageSizeChange = (newPageSize: number) => { + pageSize.value = newPageSize + page.value = 1 + loadChunks() +} + // 查看分块 const viewChunk = (chunk: any) => { selectedChunk.value = chunk showViewDialog.value = true } -// 编辑分块 -const editChunk = (chunk: any) => { - selectedChunk.value = chunk - editForm.value.content = chunk.content - showEditDialog.value = true -} - -// 关闭编辑对话框 -const closeEditDialog = () => { - showEditDialog.value = false - selectedChunk.value = null - editForm.value.content = '' -} - -// 保存分块 -const saveChunk = async () => { - if (!selectedChunk.value) return - - saving.value = true - try { - const response = await axios.post('/api/kb/chunk/update', { - chunk_id: selectedChunk.value.chunk_id, - content: editForm.value.content - }) - - if (response.data.status === 'ok') { - showSnackbar(t('edit.saveSuccess')) - closeEditDialog() - await loadChunks() - } else { - showSnackbar(response.data.message || t('edit.saveFailed'), 'error') - } - } catch (error) { - console.error('Failed to save chunk:', error) - showSnackbar(t('edit.saveFailed'), 'error') - } finally { - saving.value = false - } -} - -// 确认删除分块 -const confirmDeleteChunk = (chunk: any) => { - deleteTarget.value = chunk - showDeleteDialog.value = true -} - -// 删除分块 -const deleteChunk = async () => { - if (!deleteTarget.value) return - - deleting.value = true - try { - const response = await axios.post('/api/kb/chunk/delete', { - chunk_id: deleteTarget.value.chunk_id - }) - - if (response.data.status === 'ok') { - showSnackbar(t('delete.deleteSuccess')) - showDeleteDialog.value = false - await loadChunks() - await loadDocument() - } else { - showSnackbar(response.data.message || t('delete.deleteFailed'), 'error') - } - } catch (error) { - console.error('Failed to delete chunk:', error) - showSnackbar(t('delete.deleteFailed'), 'error') - } finally { - deleting.value = false - } -} - // 工具函数 const getFileIcon = (fileType: string) => { const type = fileType?.toLowerCase() || '' @@ -582,6 +470,10 @@ onMounted(() => { font-family: 'Consolas', 'Monaco', monospace; } +.gap-2 { + gap: 8px; +} + /* 响应式设计 */ @media (max-width: 768px) { .document-detail-page { diff --git a/dashboard/src/views/knowledge-base/KBDetail.vue b/dashboard/src/views/knowledge-base/KBDetail.vue index 61008361..ecec2465 100644 --- a/dashboard/src/views/knowledge-base/KBDetail.vue +++ b/dashboard/src/views/knowledge-base/KBDetail.vue @@ -51,7 +51,7 @@ - + @@ -163,7 +163,7 @@ - + @@ -329,12 +329,11 @@ onMounted(() => { padding: 24px; text-align: center; border-radius: 12px; - background: rgba(var(--v-theme-surface-variant), 0.3); + background: rgba(var(--v-theme-surface-variant), 0.1); transition: all 0.3s ease; } .stat-box:hover { - transform: translateY(-4px); background: rgba(var(--v-theme-surface-variant), 0.5); } @@ -342,12 +341,10 @@ onMounted(() => { font-size: 2rem; font-weight: 600; margin-top: 8px; - color: rgb(var(--v-theme-on-surface)); } .stat-label { font-size: 0.875rem; - color: rgb(var(--v-theme-on-surface-variant)); margin-top: 4px; } diff --git a/dashboard/src/views/knowledge-base/components/RetrievalTab.vue b/dashboard/src/views/knowledge-base/components/RetrievalTab.vue index 2bc5e634..249dba7e 100644 --- a/dashboard/src/views/knowledge-base/components/RetrievalTab.vue +++ b/dashboard/src/views/knowledge-base/components/RetrievalTab.vue @@ -7,10 +7,16 @@ + - + - - - - 如果没有配置重排序模型提供商,将跳过重排序步骤 - @@ -143,7 +136,8 @@ import { useModuleI18n } from '@/i18n/composables' const { tm: t } = useModuleI18n('features/knowledge-base/detail') const props = defineProps<{ - kbId: string + kbId: string, + kbName: string, }>() // 状态 @@ -179,7 +173,7 @@ const performRetrieval = async () => { try { const response = await axios.post('/api/kb/retrieve', { query: query.value, - kb_ids: [props.kbId], + kb_names: [props.kbName], top_k: topK.value, enable_rerank: enableRerank.value })