improve
This commit is contained in:
@@ -134,27 +134,9 @@ DEFAULT_CONFIG = {
|
||||
"persona": [], # deprecated
|
||||
"timezone": "Asia/Shanghai",
|
||||
"callback_api_base": "",
|
||||
"default_kb_collection": "", # 默认知识库名称
|
||||
"default_kb_collection": "", # 默认知识库名称, 已经过时
|
||||
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
|
||||
"knowledge_base": {
|
||||
"enabled": False, # 默认禁用,用户需要主动启用
|
||||
"embedding_provider_id": "", # 嵌入模型提供商 ID (为空时自动选择第一个)
|
||||
"rerank_provider_id": "", # 重排序模型提供商 ID (为空时自动选择第一个)
|
||||
"storage": {
|
||||
"files_path": "data/knowledge_base", # 文件存储路径
|
||||
"vector_db_path": "data/knowledge_base/vectors", # 向量数据库路径
|
||||
},
|
||||
"chunking": {
|
||||
"chunk_size": 512, # 文档块大小(字符数)
|
||||
"chunk_overlap": 50, # 文档块重叠大小(字符数)
|
||||
},
|
||||
"retrieval": {
|
||||
"top_k_dense": 50, # 密集检索返回结果数
|
||||
"top_k_sparse": 50, # 稀疏检索返回结果数
|
||||
"top_m_final": 5, # 最终融合后返回的结果数
|
||||
"enable_rerank": True, # 是否启用重排序
|
||||
},
|
||||
},
|
||||
"kb_names": [], # 默认知识库名称列表
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
@@ -112,9 +112,7 @@ class AstrBotCoreLifecycle:
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(
|
||||
self.astrbot_config, self.provider_manager
|
||||
)
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
@@ -141,10 +139,6 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
# 注册知识库会话生命周期钩子(零侵入级联清理)
|
||||
if self.kb_manager.is_initialized:
|
||||
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""
|
||||
知识库管理模块
|
||||
|
||||
提供文档上传、解析、分块、向量化、检索等功能
|
||||
"""
|
||||
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
# 注意: 以下导入在对应模块实现后取消注释
|
||||
from .database import KBDatabase
|
||||
from .manager import KBManager
|
||||
from .manager_ops import KBManagerOps
|
||||
from .session_config_db import SessionConfigDB
|
||||
|
||||
# from .injector import KnowledgeBaseInjector
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBase",
|
||||
"KBDocument",
|
||||
"KBChunk",
|
||||
"KBMedia",
|
||||
"KBSessionConfig",
|
||||
"KBDatabase",
|
||||
"SessionConfigDB",
|
||||
"KBManager",
|
||||
"KBManagerOps",
|
||||
# "KnowledgeBaseInjector",
|
||||
]
|
||||
@@ -1,183 +0,0 @@
|
||||
"""知识库数据库操作类
|
||||
|
||||
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
|
||||
|
||||
注意:
|
||||
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
|
||||
- 会话配置也存储在此数据库中,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
|
||||
class KBDatabase:
|
||||
"""知识库数据库操作类
|
||||
|
||||
职责:
|
||||
- 封装知识库、文档、块、多媒体和会话配置的数据库查询操作
|
||||
- 统一异常处理
|
||||
|
||||
注意:
|
||||
- 该类操作独立的知识库数据库 (kb.db)
|
||||
- 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
"""初始化知识库数据库操作类
|
||||
|
||||
Args:
|
||||
kb_db: 知识库独立数据库实例,而非主数据库
|
||||
"""
|
||||
self.db = kb_db
|
||||
|
||||
# ===== 知识库查询 =====
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_kbs(self) -> int:
|
||||
"""统计知识库数量"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(func.count(KnowledgeBase.id))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_documents_by_kb(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(KBDocument.kb_id == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBDocument.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||
"""统计知识库的文档数量"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 块查询 =====
|
||||
|
||||
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
|
||||
"""根据 ID 获取块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
|
||||
"""根据知识库 ID 列表获取所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
|
||||
"""根据向量文档 ID 获取块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
|
||||
"""获取块及其关联的文档和知识库元数据"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk, KBDocument, KnowledgeBase)
|
||||
.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
|
||||
.join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
|
||||
.where(KBChunk.chunk_id == chunk_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
chunk, doc, kb = row
|
||||
return {
|
||||
"chunk": chunk,
|
||||
"document": doc,
|
||||
"knowledge_base": kb,
|
||||
}
|
||||
|
||||
async def list_chunks_by_doc(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(KBChunk.doc_id == doc_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBChunk.chunk_index)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
@@ -1,112 +0,0 @@
|
||||
"""知识库上下文注入器
|
||||
|
||||
负责检索相关知识并格式化为 LLM 可用的上下文文本
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.manager import (
|
||||
RetrievalManager,
|
||||
RetrievalResult,
|
||||
)
|
||||
from .vec_db_factory import VecDBFactory
|
||||
|
||||
|
||||
class KnowledgeBaseInjector:
|
||||
"""知识库上下文注入器
|
||||
|
||||
职责:
|
||||
- 检索相关知识
|
||||
- 格式化为上下文文本
|
||||
- 注入到 LLM Prompt
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBDatabase,
|
||||
vec_db_factory: VecDBFactory,
|
||||
retrieval_manager: RetrievalManager,
|
||||
):
|
||||
"""初始化知识库上下文注入器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
retrieval_manager: 检索管理器实例
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self.vec_db_factory = vec_db_factory
|
||||
self.retrieval_manager = retrieval_manager
|
||||
|
||||
async def retrieve_and_inject(
|
||||
self,
|
||||
kb_ids: list[str],
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> Optional[dict]:
|
||||
"""检索并注入知识库上下文
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None
|
||||
{
|
||||
"context_text": str, # 格式化的上下文文本
|
||||
"results": List[dict], # 原始检索结果列表
|
||||
}
|
||||
"""
|
||||
# 2. 检索知识
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_m_final=top_k,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# 3. 格式化上下文
|
||||
context_text = self._format_context(results)
|
||||
|
||||
# 4. 转换结果为字典格式
|
||||
results_dict = [
|
||||
{
|
||||
"chunk_id": r.chunk_id,
|
||||
"doc_id": r.doc_id,
|
||||
"kb_id": r.kb_id,
|
||||
"kb_name": r.kb_name,
|
||||
"doc_name": r.doc_name,
|
||||
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"context_text": context_text,
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
def _format_context(self, results: List[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
str: 格式化的上下文文本
|
||||
"""
|
||||
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
lines.append(f"【知识 {i}】")
|
||||
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||
lines.append(f"内容: {result.content}")
|
||||
lines.append(f"相关度: {result.score:.2f}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,383 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import SQLModel, col, desc
|
||||
from sqlalchemy import text, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
"""初始化知识库数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# 确保目录存在
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建异步引擎
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""获取数据库会话
|
||||
|
||||
用法:
|
||||
async with kb_db.get_db() as session:
|
||||
# 执行数据库操作
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
self.inited = True
|
||||
|
||||
async def migrate_to_v1(self) -> None:
|
||||
"""执行知识库数据库 v1 迁移
|
||||
|
||||
创建所有必要的索引以优化查询性能
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
# 创建知识库表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||
"ON knowledge_bases(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||
"ON knowledge_bases(kb_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||
"ON knowledge_bases(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建文档表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||
"ON kb_documents(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||
"ON kb_documents(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||
"ON kb_documents(doc_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||
"ON kb_documents(file_type)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||
"ON kb_documents(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建块表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id "
|
||||
"ON kb_chunks(chunk_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_doc_id "
|
||||
"ON kb_chunks(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id "
|
||||
"ON kb_chunks(vec_doc_id)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建多媒体表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||
"ON kb_media(media_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||
"ON kb_media(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||
"ON kb_media(media_type)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建会话配置表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope_id "
|
||||
"ON kb_session_config(scope_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope "
|
||||
"ON kb_session_config(scope)"
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(desc(KnowledgeBase.created_at))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_kbs(self) -> int:
|
||||
"""统计知识库数量"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(func.count(col(KnowledgeBase.id)))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_documents_by_kb(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(col(KBDocument.kb_id) == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(desc(KBDocument.created_at))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||
"""统计知识库的文档数量"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(func.count(col(KBDocument.id))).where(
|
||||
col(KBDocument.kb_id) == kb_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 块查询 =====
|
||||
|
||||
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
|
||||
"""根据 ID 获取块"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBChunk).where(col(KBChunk.chunk_id) == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
|
||||
"""根据知识库 ID 列表获取所有块"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBChunk).where(col(KBChunk.kb_id).in_(kb_ids))
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
|
||||
"""根据向量文档 ID 获取块"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBChunk).where(col(KBChunk.vec_doc_id) == vec_doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunks_by_doc_id(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBChunk]:
|
||||
"""根据文档 ID 获取所有块"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(col(KBChunk.doc_id) == doc_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(col(KBChunk.chunk_index))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
|
||||
"""获取块及其关联的文档和知识库元数据"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk, KBDocument, KnowledgeBase)
|
||||
.join(KBDocument, col(KBChunk.doc_id) == col(KBDocument.doc_id))
|
||||
.join(KnowledgeBase, col(KBChunk.kb_id) == col(KnowledgeBase.kb_id))
|
||||
.where(col(KBChunk.chunk_id) == chunk_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
chunk, doc, kb = row
|
||||
return {
|
||||
"chunk": chunk,
|
||||
"document": doc,
|
||||
"knowledge_base": kb,
|
||||
}
|
||||
|
||||
async def list_chunks_by_doc(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(col(KBChunk.doc_id) == doc_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(col(KBChunk.chunk_index))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_kb_stats(self, kb_id: str):
|
||||
"""更新知识库统计信息"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
update_stmt = (
|
||||
update(KnowledgeBase)
|
||||
.where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
.values(
|
||||
doc_count=select(func.count(col(KBDocument.id)))
|
||||
.where(col(KBDocument.kb_id) == kb_id)
|
||||
.scalar_subquery(),
|
||||
chunk_count=select(func.count(col(KBChunk.id)))
|
||||
.where(col(KBChunk.kb_id) == kb_id)
|
||||
.scalar_subquery(),
|
||||
)
|
||||
)
|
||||
|
||||
await session.execute(update_stmt)
|
||||
await session.commit()
|
||||
@@ -0,0 +1,248 @@
|
||||
import uuid
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from .models import KnowledgeBase, KBDocument, KBChunk, KBMedia
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from .parsers.base import BaseParser
|
||||
from .chunking.base import BaseChunker
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class KBHelper:
|
||||
vec_db: BaseVecDB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
kb: KnowledgeBase,
|
||||
provider_manager: ProviderManager,
|
||||
kb_root_dir: str,
|
||||
chunker: BaseChunker,
|
||||
parsers: dict[str, BaseParser],
|
||||
):
|
||||
self.kb_db = kb_db
|
||||
self.kb = kb
|
||||
self.prov_mgr = provider_manager
|
||||
self.kb_root_dir = kb_root_dir
|
||||
self.parsers = parsers
|
||||
self.chunker = chunker
|
||||
|
||||
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
|
||||
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
|
||||
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
|
||||
|
||||
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def initialize(self):
|
||||
await self._ensure_vec_db()
|
||||
|
||||
async def get_ep(self) -> EmbeddingProvider:
|
||||
if not self.kb.embedding_provider_id:
|
||||
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
|
||||
self.kb.embedding_provider_id
|
||||
) # type: ignore
|
||||
if not ep:
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
|
||||
)
|
||||
return ep
|
||||
|
||||
async def get_rp(self) -> RerankProvider | None:
|
||||
if not self.kb.rerank_provider_id:
|
||||
return None
|
||||
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
|
||||
self.kb.rerank_provider_id
|
||||
) # type: ignore
|
||||
if not rp:
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
|
||||
)
|
||||
return rp
|
||||
|
||||
async def _ensure_vec_db(self) -> FaissVecDB:
|
||||
if not self.kb.embedding_provider_id:
|
||||
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||
|
||||
ep = await self.get_ep()
|
||||
rp = await self.get_rp()
|
||||
|
||||
vec_db = FaissVecDB(
|
||||
doc_store_path=str(self.kb_dir / "doc.db"),
|
||||
index_store_path=str(self.kb_dir / "index.faiss"),
|
||||
embedding_provider=ep,
|
||||
rerank_provider=rp,
|
||||
)
|
||||
await vec_db.initialize()
|
||||
self.vec_db = vec_db
|
||||
return vec_db
|
||||
|
||||
async def delete_vec_db(self):
|
||||
await self.terminate()
|
||||
if self.kb_dir.exists():
|
||||
for item in self.kb_dir.iterdir():
|
||||
if item.is_file():
|
||||
item.unlink()
|
||||
self.kb_dir.rmdir()
|
||||
|
||||
async def terminate(self):
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
流程:
|
||||
1. 保存原始文件
|
||||
2. 解析文档内容
|
||||
3. 提取多媒体资源 (TODO)
|
||||
4. 分块处理
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
"""
|
||||
await self._ensure_vec_db()
|
||||
doc_id = str(uuid.uuid4())
|
||||
media_paths: list[Path] = []
|
||||
vec_doc_ids = []
|
||||
|
||||
file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(file_content)
|
||||
|
||||
try:
|
||||
parser = self.parsers.get(file_type)
|
||||
if not parser:
|
||||
raise ValueError(f"不支持的文件类型: {file_type}")
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
# 保存媒体文件
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await self._save_media(
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 分块并生成向量
|
||||
saved_chunks = []
|
||||
chunks_text = await self.chunker.chunk(text_content)
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
vec_doc_id = await self.vec_db.insert(
|
||||
content=chunk_text,
|
||||
metadata={
|
||||
"kb_id": self.kb.kb_id,
|
||||
"doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
},
|
||||
)
|
||||
vec_doc_ids.append(str(vec_doc_id))
|
||||
|
||||
chunk = KBChunk(
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
chunk_index=idx,
|
||||
content=chunk_text,
|
||||
char_count=len(chunk_text),
|
||||
vec_doc_id=str(vec_doc_id),
|
||||
)
|
||||
saved_chunks.append(chunk)
|
||||
|
||||
# 保存文档和块的元数据
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
file_path=str(file_path),
|
||||
chunk_count=len(saved_chunks),
|
||||
media_count=0,
|
||||
)
|
||||
async with self.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
for chunk in saved_chunks:
|
||||
session.add(chunk)
|
||||
for media in saved_media:
|
||||
session.add(media)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id)
|
||||
|
||||
return doc
|
||||
except Exception as e:
|
||||
logger.error(f"上传文档失败: {e}")
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
for media_path in media_paths:
|
||||
try:
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as me:
|
||||
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||
|
||||
raise e
|
||||
|
||||
async def list_documents(
|
||||
self, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
|
||||
return docs
|
||||
|
||||
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||
"""获取单个文档"""
|
||||
doc = await self.kb_db.get_document_by_id(doc_id)
|
||||
return doc
|
||||
|
||||
async def _save_media(
|
||||
self,
|
||||
doc_id: str,
|
||||
media_type: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
mime_type: str,
|
||||
) -> KBMedia:
|
||||
"""保存多媒体资源"""
|
||||
media_id = str(uuid.uuid4())
|
||||
ext = Path(file_name).suffix
|
||||
|
||||
# 保存文件
|
||||
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
media = KBMedia(
|
||||
media_id=media_id,
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
media_type=media_type,
|
||||
file_name=file_name,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
return media
|
||||
@@ -1,205 +0,0 @@
|
||||
"""
|
||||
知识库管理器
|
||||
负责知识库模块的初始化、配置和资源管理
|
||||
|
||||
架构说明:
|
||||
- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db)
|
||||
- 会话配置存储在主数据库 (data/astrbot.db) 以便于会话关联
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from .injector import KnowledgeBaseInjector
|
||||
from .retrieval.manager import RetrievalManager
|
||||
from .retrieval.sparse_retriever import SparseRetriever
|
||||
from .retrieval.rank_fusion import RankFusion
|
||||
from .kb_sqlite import KBSQLiteDatabase
|
||||
from .database import KBDatabase
|
||||
from .vec_db_factory import VecDBFactory
|
||||
from .manager import KBManager
|
||||
from .parsers.text_parser import TextParser
|
||||
from .parsers.pdf_parser import PDFParser
|
||||
from .chunking.fixed_size import FixedSizeChunker
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
"""知识库管理器
|
||||
|
||||
职责:
|
||||
- 知识库模块的初始化
|
||||
- Embedding Provider 和 Rerank Provider 的选择
|
||||
- 各个子组件的协调管理
|
||||
- 注册会话删除回调,实现级联清理
|
||||
|
||||
架构说明:
|
||||
- 知识库数据存储在独立数据库 (kb.db)
|
||||
- 会话配置存储在独立数据库 (kb.db),会话ID来自主数据库
|
||||
- 通过回调机制实现与主数据库的生命周期同步
|
||||
"""
|
||||
|
||||
kb_db: KBSQLiteDatabase
|
||||
vec_db_factory: VecDBFactory
|
||||
kb_database: KBDatabase
|
||||
kb_manager: KBManager
|
||||
retrieval_manager: RetrievalManager
|
||||
kb_injector: KnowledgeBaseInjector
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
"""初始化知识库管理器
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
provider_manager: Provider 管理器
|
||||
"""
|
||||
self.config = config.get("knowledge_base", {})
|
||||
self.provider_manager = provider_manager
|
||||
self._initialized = False
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化知识库模块"""
|
||||
if not self.config.get("enabled", False):
|
||||
logger.info("知识库功能未启用")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 初始化数据库
|
||||
await self._init_kb_database()
|
||||
|
||||
# 初始化向量数据库工厂
|
||||
await self._init_vector_db_factory()
|
||||
|
||||
# 初始化解析器和分块器
|
||||
parsers = {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
chunking_config = self.config.get("chunking", {})
|
||||
chunker = FixedSizeChunker(
|
||||
chunk_size=chunking_config.get("chunk_size", 512),
|
||||
chunk_overlap=chunking_config.get("chunk_overlap", 50),
|
||||
)
|
||||
|
||||
# 初始化知识库管理器
|
||||
files_path = self.config.get("storage", {}).get(
|
||||
"files_path", "data/knowledge_base"
|
||||
)
|
||||
self.kb_manager = KBManager(
|
||||
db=self.kb_db,
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
storage_path=files_path,
|
||||
parsers=parsers,
|
||||
chunker=chunker,
|
||||
provider_manager=self.provider_manager,
|
||||
)
|
||||
|
||||
# 初始化检索管理器
|
||||
sparse_retriever = SparseRetriever(self.kb_database)
|
||||
rank_fusion = RankFusion(self.kb_database)
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_database,
|
||||
)
|
||||
|
||||
# 初始化上下文注入器
|
||||
self.kb_injector = KnowledgeBaseInjector(
|
||||
kb_db=self.kb_database,
|
||||
vec_db_factory=self.vec_db_factory,
|
||||
retrieval_manager=self.retrieval_manager,
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("知识库模块初始化完成")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"知识库模块导入失败: {e}")
|
||||
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||
except Exception as e:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
"""初始化知识库独立数据库"""
|
||||
db_path = self.config.get("storage", {}).get(
|
||||
"kb_db_path", "data/knowledge_base/kb.db"
|
||||
)
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.kb_db = KBSQLiteDatabase(db_path)
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
self.kb_database = KBDatabase(self.kb_db)
|
||||
logger.info(f"KnowledgeBase database initialized: {db_path}")
|
||||
|
||||
async def _init_vector_db_factory(self):
|
||||
"""初始化向量数据库工厂"""
|
||||
storage_path = self.config.get("storage", {}).get(
|
||||
"vector_db_path", "data/knowledge_base/vectors"
|
||||
)
|
||||
Path(storage_path).mkdir(parents=True, exist_ok=True)
|
||||
self.vec_db_factory = VecDBFactory(storage_base_path=storage_path)
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""检查是否已初始化"""
|
||||
return self._initialized
|
||||
|
||||
def get_kb_manager(self):
|
||||
"""获取知识库管理器"""
|
||||
return self.kb_manager if self._initialized else None
|
||||
|
||||
def get_kb_injector(self):
|
||||
"""获取知识库上下文注入器"""
|
||||
return self.kb_injector if self._initialized else None
|
||||
|
||||
async def reinitialize(self):
|
||||
"""重新初始化知识库模块
|
||||
|
||||
用于在运行时动态初始化知识库模块(例如用户添加了 embedding provider 后)
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.info("知识库模块已初始化,将重新初始化")
|
||||
await self.terminate()
|
||||
|
||||
await self.initialize()
|
||||
return self._initialized
|
||||
|
||||
async def terminate(self):
|
||||
"""终止知识库模块,清理资源"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("正在终止知识库模块...")
|
||||
|
||||
# 关闭向量数据库工厂(关闭所有向量数据库实例)
|
||||
if self.vec_db_factory:
|
||||
try:
|
||||
await self.vec_db_factory.close_all()
|
||||
logger.debug("向量数据库工厂已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭向量数据库工厂时出错: {e}")
|
||||
|
||||
# 关闭知识库独立数据库连接
|
||||
if self.kb_db:
|
||||
try:
|
||||
await self.kb_db.close()
|
||||
logger.debug("知识库数据库已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭知识库数据库时出错: {e}")
|
||||
|
||||
self._initialized = False
|
||||
|
||||
logger.info("知识库模块已终止")
|
||||
@@ -0,0 +1,275 @@
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
|
||||
from .retrieval.manager import RetrievalManager, RetrievalResult
|
||||
from .retrieval.sparse_retriever import SparseRetriever
|
||||
from .retrieval.rank_fusion import RankFusion
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
|
||||
from .parsers.text_parser import TextParser
|
||||
from .parsers.pdf_parser import PDFParser
|
||||
from .chunking.fixed_size import FixedSizeChunker
|
||||
from .kb_helper import KBHelper
|
||||
|
||||
from .models import KnowledgeBase
|
||||
|
||||
|
||||
FILES_PATH = "data/knowledge_base"
|
||||
DB_PATH = Path(FILES_PATH) / "kb.db"
|
||||
"""Knowledge Base storage root directory"""
|
||||
PARSERS = {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
CHUNKER = FixedSizeChunker()
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
kb_db: KBSQLiteDatabase
|
||||
retrieval_manager: RetrievalManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||||
self.provider_manager = provider_manager
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
self.kb_insts: dict[str, KBHelper] = {}
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化知识库模块"""
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 初始化数据库
|
||||
await self._init_kb_database()
|
||||
|
||||
# 初始化检索管理器
|
||||
sparse_retriever = SparseRetriever(self.kb_db)
|
||||
rank_fusion = RankFusion(self.kb_db)
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_db,
|
||||
)
|
||||
await self.load_kbs()
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"知识库模块导入失败: {e}")
|
||||
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||
except Exception as e:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
|
||||
|
||||
async def load_kbs(self):
|
||||
"""加载所有知识库实例"""
|
||||
kb_records = await self.kb_db.list_kbs()
|
||||
for record in kb_records:
|
||||
kb_helper = KBHelper(
|
||||
kb_db=self.kb_db,
|
||||
kb=record,
|
||||
provider_manager=self.provider_manager,
|
||||
kb_root_dir=FILES_PATH,
|
||||
chunker=CHUNKER,
|
||||
parsers=PARSERS,
|
||||
)
|
||||
await kb_helper.initialize()
|
||||
self.kb_insts[record.kb_id] = kb_helper
|
||||
|
||||
async def create_kb(
|
||||
self,
|
||||
kb_name: str,
|
||||
description: str | None = None,
|
||||
emoji: str | None = None,
|
||||
embedding_provider_id: str | None = None,
|
||||
rerank_provider_id: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
top_k_dense: int | None = None,
|
||||
top_k_sparse: int | None = None,
|
||||
top_m_final: int | None = None,
|
||||
) -> KBHelper:
|
||||
"""创建新的知识库实例"""
|
||||
kb = KnowledgeBase(
|
||||
kb_name=kb_name,
|
||||
description=description,
|
||||
emoji=emoji or "📚",
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||
)
|
||||
async with self.kb_db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
|
||||
kb_helper = KBHelper(
|
||||
kb_db=self.kb_db,
|
||||
kb=kb,
|
||||
provider_manager=self.provider_manager,
|
||||
kb_root_dir=FILES_PATH,
|
||||
chunker=CHUNKER,
|
||||
parsers=PARSERS,
|
||||
)
|
||||
await kb_helper.initialize()
|
||||
self.kb_insts[kb.kb_id] = kb_helper
|
||||
return kb_helper
|
||||
|
||||
async def get_kb(self, kb_id: str) -> KBHelper | None:
|
||||
"""获取知识库实例"""
|
||||
if kb_id in self.kb_insts:
|
||||
return self.kb_insts[kb_id]
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
|
||||
"""通过名称获取知识库实例"""
|
||||
for kb_helper in self.kb_insts.values():
|
||||
if kb_helper.kb.kb_name == kb_name:
|
||||
return kb_helper
|
||||
return None
|
||||
|
||||
async def delete_kb(self, kb_id: str) -> bool:
|
||||
"""删除知识库实例"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return False
|
||||
|
||||
await kb_helper.delete_vec_db()
|
||||
async with self.kb_db.get_db() as session:
|
||||
await session.delete(kb_helper.kb)
|
||||
await session.commit()
|
||||
|
||||
self.kb_insts.pop(kb_id, None)
|
||||
return True
|
||||
|
||||
async def list_kbs(self) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库实例"""
|
||||
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
|
||||
return kbs
|
||||
|
||||
async def update_kb(
|
||||
self,
|
||||
kb_id: str,
|
||||
kb_name: str,
|
||||
description: str | None = None,
|
||||
emoji: str | None = None,
|
||||
embedding_provider_id: str | None = None,
|
||||
rerank_provider_id: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
top_k_dense: int | None = None,
|
||||
top_k_sparse: int | None = None,
|
||||
top_m_final: int | None = None,
|
||||
) -> KBHelper | None:
|
||||
"""更新知识库实例"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return None
|
||||
|
||||
kb = kb_helper.kb
|
||||
if kb_name is not None:
|
||||
kb.kb_name = kb_name
|
||||
if description is not None:
|
||||
kb.description = description
|
||||
if emoji is not None:
|
||||
kb.emoji = emoji
|
||||
if embedding_provider_id is not None:
|
||||
kb.embedding_provider_id = embedding_provider_id
|
||||
if rerank_provider_id is not None:
|
||||
kb.rerank_provider_id = rerank_provider_id
|
||||
if chunk_size is not None:
|
||||
kb.chunk_size = chunk_size
|
||||
if chunk_overlap is not None:
|
||||
kb.chunk_overlap = chunk_overlap
|
||||
if top_k_dense is not None:
|
||||
kb.top_k_dense = top_k_dense
|
||||
if top_k_sparse is not None:
|
||||
kb.top_k_sparse = top_k_sparse
|
||||
if top_m_final is not None:
|
||||
kb.top_m_final = top_m_final
|
||||
async with self.kb_db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_names: list[str],
|
||||
top_m_final: int = 5,
|
||||
) -> dict | None:
|
||||
"""从指定知识库中检索相关内容"""
|
||||
kb_ids = []
|
||||
kb_id_helper_map = {}
|
||||
for kb_name in kb_names:
|
||||
if kb_helper := await self.get_kb_by_name(kb_name):
|
||||
kb_ids.append(kb_helper.kb.kb_id)
|
||||
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
|
||||
|
||||
if not kb_ids:
|
||||
return {}
|
||||
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
kb_id_helper_map=kb_id_helper_map,
|
||||
top_m_final=top_m_final,
|
||||
)
|
||||
if not results:
|
||||
return None
|
||||
|
||||
context_text = self._format_context(results)
|
||||
|
||||
results_dict = [
|
||||
{
|
||||
"chunk_id": r.chunk_id,
|
||||
"doc_id": r.doc_id,
|
||||
"kb_id": r.kb_id,
|
||||
"kb_name": r.kb_name,
|
||||
"doc_name": r.doc_name,
|
||||
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"context_text": context_text,
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
def _format_context(self, results: list[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
str: 格式化的上下文文本
|
||||
"""
|
||||
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
lines.append(f"【知识 {i}】")
|
||||
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||
lines.append(f"内容: {result.content}")
|
||||
lines.append(f"相关度: {result.score:.2f}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -1,230 +0,0 @@
|
||||
"""
|
||||
知识库独立 SQLite 数据库
|
||||
|
||||
该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。
|
||||
职责:
|
||||
- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media)
|
||||
- 提供数据库连接和会话管理
|
||||
- 执行数据库迁移和初始化
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
"""知识库独立 SQLite 数据库
|
||||
|
||||
与主数据库 (astrbot.db) 完全隔离的独立数据库,专门用于存储知识库数据。
|
||||
|
||||
特点:
|
||||
- 数据隔离: 知识库数据不会影响主数据库格式
|
||||
- 独立备份: 可以单独备份和恢复知识库数据
|
||||
- 性能隔离: 大量知识库查询不会影响主业务性能
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
"""初始化知识库数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径,默认为 data/knowledge_base/kb.db
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# 确保目录存在
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建异步引擎
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""获取数据库会话
|
||||
|
||||
用法:
|
||||
async with kb_db.get_db() as session:
|
||||
# 执行数据库操作
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
# noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表
|
||||
from astrbot.core.knowledge_base.models import ( # noqa: F401
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
self.inited = True
|
||||
logger.info(f"知识库数据库已初始化: {self.db_path}")
|
||||
|
||||
async def migrate_to_v1(self) -> None:
|
||||
"""执行知识库数据库 v1 迁移
|
||||
|
||||
创建所有必要的索引以优化查询性能
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
# 创建知识库表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||
"ON knowledge_bases(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||
"ON knowledge_bases(kb_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||
"ON knowledge_bases(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建文档表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||
"ON kb_documents(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||
"ON kb_documents(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||
"ON kb_documents(doc_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||
"ON kb_documents(file_type)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||
"ON kb_documents(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建块表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id "
|
||||
"ON kb_chunks(chunk_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_doc_id "
|
||||
"ON kb_chunks(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id "
|
||||
"ON kb_chunks(vec_doc_id)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建多媒体表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||
"ON kb_media(media_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||
"ON kb_media(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||
"ON kb_media(media_type)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建会话配置表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope_id "
|
||||
"ON kb_session_config(scope_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope "
|
||||
"ON kb_session_config(scope)"
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info("知识库数据库迁移 v1 完成")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
@@ -1,430 +0,0 @@
|
||||
"""知识库管理器
|
||||
|
||||
该模块提供知识库的CRUD操作和文档上传处理流程。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from sqlalchemy import func, select, update
|
||||
|
||||
from .kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.chunking.base import BaseChunker
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
|
||||
from astrbot.core.knowledge_base.parsers.base import BaseParser
|
||||
from .vec_db_factory import VecDBFactory
|
||||
|
||||
class KBManager:
|
||||
"""知识库管理器
|
||||
|
||||
职责:
|
||||
- 知识库的 CRUD 操作
|
||||
- 文档上传与解析
|
||||
- 文档块生成与存储
|
||||
- 多媒体资源管理
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: KBSQLiteDatabase,
|
||||
vec_db_factory: VecDBFactory,
|
||||
storage_path: str,
|
||||
parsers: dict[str, BaseParser],
|
||||
chunker: BaseChunker,
|
||||
provider_manager=None,
|
||||
):
|
||||
self.db = db
|
||||
self.vec_db_factory = vec_db_factory
|
||||
self.storage_path = Path(storage_path)
|
||||
self.media_path = self.storage_path / "media"
|
||||
self.files_path = self.storage_path / "files"
|
||||
self.parsers = parsers
|
||||
self.chunker = chunker
|
||||
self.provider_manager = provider_manager
|
||||
|
||||
# 确保目录存在
|
||||
self.media_path.mkdir(parents=True, exist_ok=True)
|
||||
self.files_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _get_embedding_provider_for_kb(self, kb_id: str):
|
||||
"""根据知识库配置获取 Embedding Provider
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider: Embedding Provider 实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果找不到合适的 embedding provider
|
||||
"""
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
|
||||
# 获取知识库配置
|
||||
kb_database = KBDatabase(self.db)
|
||||
kb = await kb_database.get_kb_by_id(kb_id)
|
||||
if not kb:
|
||||
raise ValueError(f"知识库不存在: {kb_id}")
|
||||
|
||||
embedding_provider_id = kb.embedding_provider_id
|
||||
|
||||
# 如果没有 provider_manager,使用默认的第一个
|
||||
if not self.provider_manager:
|
||||
raise ValueError("Provider Manager 未初始化")
|
||||
|
||||
embedding_providers = self.provider_manager.embedding_provider_insts
|
||||
if not embedding_providers:
|
||||
raise ValueError("系统中没有可用的 Embedding Provider")
|
||||
|
||||
# 如果指定了 provider ID,则查找该 provider
|
||||
if embedding_provider_id:
|
||||
for provider in embedding_providers:
|
||||
if provider.meta().id == embedding_provider_id:
|
||||
return provider
|
||||
raise ValueError(
|
||||
f"未找到配置的 Embedding Provider: {embedding_provider_id}"
|
||||
)
|
||||
|
||||
# 使用第一个可用的 provider
|
||||
return embedding_providers[0]
|
||||
|
||||
# ===== 知识库操作 =====
|
||||
|
||||
async def create_kb(
|
||||
self,
|
||||
kb_name: str,
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[str] = None,
|
||||
embedding_provider_id: Optional[str] = None,
|
||||
rerank_provider_id: Optional[str] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
top_k_dense: Optional[int] = None,
|
||||
top_k_sparse: Optional[int] = None,
|
||||
top_m_final: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KnowledgeBase:
|
||||
"""创建知识库
|
||||
|
||||
Args:
|
||||
enable_rerank: 是否启用重排序。
|
||||
- 如果明确传入 True/False,则使用该值
|
||||
- 如果为 None,则根据是否有可用的 rerank provider 自动决定
|
||||
"""
|
||||
# 智能决定 enable_rerank 的默认值
|
||||
if enable_rerank is None:
|
||||
# 检查是否有可用的 rerank provider
|
||||
has_rerank_provider = (
|
||||
self.provider_manager
|
||||
and hasattr(self.provider_manager, "rerank_provider_insts")
|
||||
and len(self.provider_manager.rerank_provider_insts) > 0
|
||||
)
|
||||
enable_rerank = has_rerank_provider
|
||||
|
||||
kb = KnowledgeBase(
|
||||
kb_name=kb_name,
|
||||
description=description,
|
||||
emoji=emoji or "📚",
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
async with self.db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return kb
|
||||
|
||||
async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_kb(
|
||||
self,
|
||||
kb_id: str,
|
||||
kb_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[str] = None,
|
||||
embedding_provider_id: Optional[str] = None,
|
||||
rerank_provider_id: Optional[str] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
top_k_dense: Optional[int] = None,
|
||||
top_k_sparse: Optional[int] = None,
|
||||
top_m_final: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> Optional[KnowledgeBase]:
|
||||
"""更新知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if not kb:
|
||||
return None
|
||||
|
||||
if kb_name is not None:
|
||||
kb.kb_name = kb_name
|
||||
if description is not None:
|
||||
kb.description = description
|
||||
if emoji is not None:
|
||||
kb.emoji = emoji
|
||||
if embedding_provider_id is not None:
|
||||
kb.embedding_provider_id = embedding_provider_id
|
||||
if rerank_provider_id is not None:
|
||||
kb.rerank_provider_id = rerank_provider_id
|
||||
if chunk_size is not None:
|
||||
kb.chunk_size = chunk_size
|
||||
if chunk_overlap is not None:
|
||||
kb.chunk_overlap = chunk_overlap
|
||||
if top_k_dense is not None:
|
||||
kb.top_k_dense = top_k_dense
|
||||
if top_k_sparse is not None:
|
||||
kb.top_k_sparse = top_k_sparse
|
||||
if top_m_final is not None:
|
||||
kb.top_m_final = top_m_final
|
||||
if enable_rerank is not None:
|
||||
kb.enable_rerank = enable_rerank
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return kb
|
||||
|
||||
async def delete_kb(self, kb_id: str) -> bool:
|
||||
"""删除知识库(级联删除所有文档和资源)"""
|
||||
# 1. 获取所有文档
|
||||
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
|
||||
|
||||
ops = KBManagerOps(self)
|
||||
docs = await ops.list_documents(kb_id)
|
||||
|
||||
# 2. 删除所有文档(包括文件和向量)
|
||||
for doc in docs:
|
||||
await ops.delete_document(doc.doc_id)
|
||||
|
||||
# 3. 删除向量数据库
|
||||
await self.vec_db_factory.delete_vec_db(kb_id)
|
||||
|
||||
# 4. 删除知识库记录
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if not kb:
|
||||
return False
|
||||
|
||||
await session.delete(kb)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
# ===== 文档上传 =====
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
kb_id: str,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
流程:
|
||||
1. 保存原始文件
|
||||
2. 解析文档内容
|
||||
3. 提取多媒体资源
|
||||
4. 分块处理
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
"""
|
||||
doc_id = str(uuid.uuid4())
|
||||
file_path = None
|
||||
media_paths = []
|
||||
vec_doc_ids = []
|
||||
|
||||
try:
|
||||
# 1. 保存原始文件
|
||||
file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(file_content)
|
||||
|
||||
# 2. 解析文档
|
||||
parser = self.parsers.get(file_type)
|
||||
if not parser:
|
||||
raise ValueError(f"不支持的文件类型: {file_type}")
|
||||
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
# 3. 保存多媒体资源
|
||||
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
|
||||
|
||||
ops = KBManagerOps(self)
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await ops._save_media(
|
||||
kb_id=kb_id,
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 4. 文档分块
|
||||
chunks_text = await self.chunker.chunk(text_content)
|
||||
|
||||
# 5. 获取 Embedding Provider 和向量数据库
|
||||
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
|
||||
|
||||
# 6. 生成向量并存储
|
||||
saved_chunks = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
# 存储到向量数据库
|
||||
vec_doc_id = await vec_db.insert(
|
||||
content=chunk_text,
|
||||
metadata={
|
||||
"kb_id": kb_id,
|
||||
"doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
},
|
||||
)
|
||||
vec_doc_ids.append(str(vec_doc_id))
|
||||
|
||||
# 保存块元数据
|
||||
chunk = KBChunk(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
chunk_index=idx,
|
||||
content=chunk_text,
|
||||
char_count=len(chunk_text),
|
||||
vec_doc_id=str(vec_doc_id),
|
||||
)
|
||||
saved_chunks.append(chunk)
|
||||
|
||||
# 7. 保存文档元数据(事务)
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
file_path=str(file_path),
|
||||
chunk_count=len(saved_chunks),
|
||||
media_count=len(saved_media),
|
||||
)
|
||||
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
for chunk in saved_chunks:
|
||||
session.add(chunk)
|
||||
for media in saved_media:
|
||||
session.add(media)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
# 8. 更新知识库统计
|
||||
await self._update_kb_stats(kb_id)
|
||||
|
||||
return doc
|
||||
|
||||
except Exception as e:
|
||||
# 失败清理:删除已创建的资源
|
||||
from astrbot.core import logger
|
||||
|
||||
logger.error(f"文档上传失败,开始清理资源: {e}")
|
||||
|
||||
# 获取知识库的向量数据库
|
||||
try:
|
||||
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
|
||||
|
||||
# 清理向量数据库
|
||||
for vec_id in vec_doc_ids:
|
||||
try:
|
||||
await vec_db.delete(vec_id)
|
||||
except Exception as ve:
|
||||
logger.warning(f"清理向量失败 {vec_id}: {ve}")
|
||||
except Exception as vfe:
|
||||
logger.error(f"获取向量数据库失败: {vfe}")
|
||||
|
||||
# 清理多媒体文件
|
||||
for media_path in media_paths:
|
||||
try:
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as me:
|
||||
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||
|
||||
# 清理文档文件
|
||||
if file_path and file_path.exists():
|
||||
try:
|
||||
file_path.unlink()
|
||||
except Exception as fe:
|
||||
logger.warning(f"清理文档文件失败 {file_path}: {fe}")
|
||||
|
||||
# 重新抛出原始异常
|
||||
raise
|
||||
|
||||
# ===== 统计更新 =====
|
||||
|
||||
async def _update_kb_stats(self, kb_id: str):
|
||||
"""更新知识库统计信息(事务中执行)"""
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计文档数(在事务中查询)
|
||||
doc_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBDocument.id)).where(
|
||||
KBDocument.kb_id == kb_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计块数(在事务中查询)
|
||||
chunk_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# 更新知识库(在同一事务中)
|
||||
await session.execute(
|
||||
update(KnowledgeBase)
|
||||
.where(KnowledgeBase.kb_id == kb_id)
|
||||
.values(doc_count=doc_count, chunk_count=chunk_count)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -1,323 +0,0 @@
|
||||
"""知识库管理器辅助操作
|
||||
|
||||
该模块提供文档、块和多媒体的管理操作。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiofiles
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.manager import KBManager
|
||||
|
||||
|
||||
class KBManagerOps:
|
||||
"""知识库管理器辅助操作类
|
||||
|
||||
职责:
|
||||
- 文档管理操作
|
||||
- 块管理操作
|
||||
- 多媒体管理操作
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "KBManager"):
|
||||
self.manager = manager
|
||||
self.db = manager.db
|
||||
self.vec_db_factory = manager.vec_db_factory
|
||||
self.media_path = manager.media_path
|
||||
self.files_path = manager.files_path
|
||||
|
||||
# ===== 文档操作 =====
|
||||
|
||||
async def list_documents(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(KBDocument.kb_id == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBDocument.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||
"""获取文档详情"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def delete_document(self, doc_id: str) -> bool:
|
||||
"""删除文档(级联删除块、多媒体、向量)
|
||||
|
||||
采用三阶段删除策略:
|
||||
1. 删除向量数据库中的向量(允许部分失败)
|
||||
2. 删除SQL数据库中的记录(事务保证原子性)
|
||||
3. 删除文件系统中的文件(失败不影响数据一致性)
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 0. 获取文档信息
|
||||
doc = await self.get_document(doc_id)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# 收集所有需要删除的资源
|
||||
chunks = await self.list_chunks(doc_id)
|
||||
media_list = await self.list_media(doc_id)
|
||||
|
||||
# 获取知识库的向量数据库
|
||||
embedding_provider = await self.manager._get_embedding_provider_for_kb(
|
||||
doc.kb_id
|
||||
)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(doc.kb_id, embedding_provider)
|
||||
|
||||
# ===== 第一阶段: 删除向量(可重试) =====
|
||||
vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
|
||||
deleted_vec_ids = []
|
||||
failed_vec_ids = []
|
||||
|
||||
for vec_id in vec_ids_to_delete:
|
||||
try:
|
||||
await vec_db.delete(vec_id)
|
||||
deleted_vec_ids.append(vec_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_id}, {e}")
|
||||
failed_vec_ids.append(vec_id)
|
||||
|
||||
# 如果向量删除失败过多(超过50%),中止操作
|
||||
if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
|
||||
logger.error(
|
||||
f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
|
||||
)
|
||||
return False
|
||||
|
||||
# 记录部分失败但继续执行
|
||||
if failed_vec_ids:
|
||||
logger.warning(
|
||||
f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
|
||||
)
|
||||
|
||||
# ===== 第二阶段: 删除数据库记录(事务) =====
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 删除块记录
|
||||
await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
|
||||
|
||||
# 删除多媒体记录
|
||||
await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
|
||||
|
||||
# 删除文档记录
|
||||
await session.execute(
|
||||
delete(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# ===== 第三阶段: 删除文件(失败不影响) =====
|
||||
# 删除多媒体文件
|
||||
for media in media_list:
|
||||
try:
|
||||
media_path = Path(media.file_path)
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")
|
||||
|
||||
# 删除文档文件
|
||||
try:
|
||||
file_path = Path(doc.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")
|
||||
|
||||
# ===== 更新统计 =====
|
||||
await self.manager._update_kb_stats(doc.kb_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 块操作 =====
|
||||
|
||||
async def list_chunks(self, doc_id: str) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(KBChunk.doc_id == doc_id)
|
||||
.order_by(KBChunk.chunk_index)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> bool:
|
||||
"""删除单个块
|
||||
|
||||
流程:
|
||||
1. 查询块信息
|
||||
2. 删除向量
|
||||
3. 删除数据库记录
|
||||
4. 更新文档统计
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 1. 查询块信息
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
chunk = result.scalar_one_or_none()
|
||||
if not chunk:
|
||||
return False
|
||||
|
||||
doc_id = chunk.doc_id
|
||||
kb_id = chunk.kb_id
|
||||
vec_doc_id = chunk.vec_doc_id
|
||||
|
||||
# 2. 获取知识库的向量数据库并删除向量
|
||||
try:
|
||||
embedding_provider = await self.manager._get_embedding_provider_for_kb(
|
||||
kb_id
|
||||
)
|
||||
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
|
||||
await vec_db.delete(vec_doc_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
|
||||
return False
|
||||
|
||||
# 3. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 4. 更新文档统计
|
||||
await self._update_doc_stats(doc_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 多媒体操作 =====
|
||||
|
||||
async def list_media(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_media(self, media_id: str) -> bool:
|
||||
"""删除多媒体资源
|
||||
|
||||
流程:
|
||||
1. 查询媒体信息
|
||||
2. 删除数据库记录
|
||||
3. 删除文件(失败不影响)
|
||||
4. 更新文档统计
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 1. 查询媒体信息
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
media = result.scalar_one_or_none()
|
||||
if not media:
|
||||
return False
|
||||
|
||||
doc_id = media.doc_id
|
||||
file_path_str = media.file_path
|
||||
|
||||
# 2. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(KBMedia).where(KBMedia.media_id == media_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 3. 删除文件(失败不影响)
|
||||
try:
|
||||
media_path = Path(file_path_str)
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}")
|
||||
|
||||
# 4. 更新文档统计
|
||||
await self._update_doc_stats(doc_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 内部辅助方法 =====
|
||||
|
||||
async def _save_media(
|
||||
self,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
media_type: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
mime_type: str,
|
||||
) -> KBMedia:
|
||||
"""保存多媒体资源"""
|
||||
media_id = str(uuid.uuid4())
|
||||
ext = Path(file_name).suffix
|
||||
|
||||
# 保存文件
|
||||
file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
# 创建记录
|
||||
media = KBMedia(
|
||||
media_id=media_id,
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
media_type=media_type,
|
||||
file_name=file_name,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
return media
|
||||
|
||||
async def _update_doc_stats(self, doc_id: str):
|
||||
"""更新文档统计信息(事务中执行)"""
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计块数
|
||||
chunk_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
|
||||
)
|
||||
) or 0
|
||||
|
||||
# 统计多媒体数
|
||||
media_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
|
||||
)
|
||||
) or 0
|
||||
|
||||
# 更新文档
|
||||
doc = await session.scalar(
|
||||
select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
)
|
||||
if doc:
|
||||
doc.chunk_count = chunk_count
|
||||
doc.media_count = media_count
|
||||
|
||||
await session.commit()
|
||||
@@ -1,22 +1,8 @@
|
||||
"""知识库管理功能的数据模型定义
|
||||
|
||||
该模块定义了知识库系统所需的数据模型,包括:
|
||||
- KnowledgeBase: 知识库表 (存储在独立的 kb.db)
|
||||
- KBDocument: 文档表 (存储在独立的 kb.db)
|
||||
- KBChunk: 文档块表 (存储在独立的 kb.db)
|
||||
- KBMedia: 多媒体资源表 (存储在独立的 kb.db)
|
||||
- KBSessionConfig: 会话配置表 (存储在独立的 kb.db)
|
||||
|
||||
注意:
|
||||
- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db)
|
||||
- 与主数据库 (astrbot.db) 完全解耦
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel, Text
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
@@ -49,7 +35,6 @@ class KnowledgeBase(SQLModel, table=True):
|
||||
top_k_dense: Optional[int] = Field(default=50, nullable=True)
|
||||
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
|
||||
top_m_final: Optional[int] = Field(default=5, nullable=True)
|
||||
enable_rerank: Optional[bool] = Field(default=False, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
@@ -58,6 +43,13 @@ class KnowledgeBase(SQLModel, table=True):
|
||||
doc_count: int = Field(default=0, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"kb_name",
|
||||
name="uix_kb_name",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KBDocument(SQLModel, table=True):
|
||||
"""文档表
|
||||
|
||||
@@ -51,10 +51,10 @@ class PDFParser(BaseParser):
|
||||
continue
|
||||
|
||||
resources = page["/Resources"]
|
||||
if not resources or "/XObject" not in resources:
|
||||
if not resources or "/XObject" not in resources: # type: ignore
|
||||
continue
|
||||
|
||||
xobjects = resources["/XObject"].get_object()
|
||||
xobjects = resources["/XObject"].get_object() # type: ignore
|
||||
if not xobjects:
|
||||
continue
|
||||
|
||||
|
||||
@@ -6,12 +6,15 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB, Result
|
||||
from ..vec_db_factory import VecDBFactory
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
from ..kb_helper import KBHelper
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
@@ -37,10 +40,9 @@ class RetrievalManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vec_db_factory, # VecDBFactory
|
||||
sparse_retriever: SparseRetriever,
|
||||
rank_fusion: RankFusion,
|
||||
kb_db: KBDatabase,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
):
|
||||
"""初始化检索管理器
|
||||
|
||||
@@ -50,21 +52,16 @@ class RetrievalManager:
|
||||
rank_fusion: 结果融合器
|
||||
kb_db: 知识库数据库实例
|
||||
"""
|
||||
self.vec_db_factory = vec_db_factory
|
||||
self.sparse_retriever = sparse_retriever
|
||||
self.rank_fusion = rank_fusion
|
||||
self.kb_db = kb_db
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
vec_db_factory: VecDBFactory,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k_dense: int = 50,
|
||||
top_k_sparse: int = 50,
|
||||
top_n_fusion: int = 20,
|
||||
kb_id_helper_map: dict[str, KBHelper],
|
||||
top_m_final: int = 5,
|
||||
rerank_provider: RerankProvider | None = None,
|
||||
) -> List[RetrievalResult]:
|
||||
"""混合检索
|
||||
|
||||
@@ -77,35 +74,52 @@ class RetrievalManager:
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k_dense: 稠密检索返回数量
|
||||
top_k_sparse: 稀疏检索返回数量
|
||||
top_n_fusion: 融合后返回数量
|
||||
top_m_final: 最终返回数量
|
||||
enable_rerank: 是否启用 Rerank
|
||||
|
||||
Returns:
|
||||
List[RetrievalResult]: 检索结果列表
|
||||
"""
|
||||
if not kb_ids:
|
||||
return []
|
||||
|
||||
kb_options: dict = {}
|
||||
new_kb_ids = []
|
||||
for kb_id in kb_ids:
|
||||
kb_helper = kb_id_helper_map.get(kb_id)
|
||||
if kb_helper:
|
||||
kb = kb_helper.kb
|
||||
kb_options[kb_id] = {
|
||||
"top_k_dense": kb.top_k_dense or 50,
|
||||
"top_k_sparse": kb.top_k_sparse or 50,
|
||||
"top_m_final": kb.top_m_final or 5,
|
||||
"vec_db": kb_helper.vec_db,
|
||||
}
|
||||
new_kb_ids.append(kb_id)
|
||||
else:
|
||||
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
|
||||
|
||||
kb_ids = new_kb_ids
|
||||
|
||||
# 1. 稠密检索
|
||||
dense_results = await self._dense_retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_dense,
|
||||
vec_db=vec_db,
|
||||
kb_options=kb_options,
|
||||
)
|
||||
|
||||
# 2. 稀疏检索
|
||||
sparse_results = await self.sparse_retriever.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_sparse,
|
||||
kb_options=kb_options,
|
||||
)
|
||||
|
||||
# 3. 结果融合
|
||||
fused_results = await self.rank_fusion.fuse(
|
||||
dense_results=dense_results,
|
||||
sparse_results=sparse_results,
|
||||
top_k=top_n_fusion,
|
||||
top_k=kb_options.get("top_k_fusion", 20),
|
||||
)
|
||||
|
||||
# 4. 转换为 RetrievalResult (获取元数据)
|
||||
@@ -130,24 +144,27 @@ class RetrievalManager:
|
||||
)
|
||||
|
||||
# 5. Rerank
|
||||
if rerank_provider and retrieval_results:
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
if vec_db and vec_db.rerank_provider:
|
||||
first_rerank = vec_db.rerank_provider
|
||||
break
|
||||
if first_rerank and retrieval_results:
|
||||
retrieval_results = await self._rerank(
|
||||
query=query,
|
||||
results=retrieval_results,
|
||||
top_k=top_m_final,
|
||||
rerank_provider=rerank_provider,
|
||||
rerank_provider=first_rerank,
|
||||
)
|
||||
else:
|
||||
retrieval_results = retrieval_results[:top_m_final]
|
||||
|
||||
return retrieval_results
|
||||
return retrieval_results[:top_m_final]
|
||||
|
||||
async def _dense_retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int,
|
||||
vec_db: BaseVecDB,
|
||||
kb_options: dict,
|
||||
):
|
||||
"""稠密检索 (向量相似度)
|
||||
|
||||
@@ -162,13 +179,17 @@ class RetrievalManager:
|
||||
List[Result]: 检索结果列表
|
||||
"""
|
||||
all_results: list[Result] = []
|
||||
|
||||
for kb_id in kb_ids:
|
||||
if kb_id not in kb_options:
|
||||
continue
|
||||
try:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
dense_k = int(kb_options[kb_id]["top_k_dense"])
|
||||
vec_results = await vec_db.retrieve(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
fetch_k=top_k * 2,
|
||||
k=dense_k,
|
||||
fetch_k=dense_k * 2,
|
||||
rerank=False, # 稠密检索阶段不进行 rerank
|
||||
metadata_filters={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
@@ -181,7 +202,7 @@ class RetrievalManager:
|
||||
|
||||
# 按相似度排序并返回 top_k
|
||||
all_results.sort(key=lambda x: x.similarity, reverse=True)
|
||||
return all_results[:top_k]
|
||||
return all_results[: len(all_results) // len(kb_ids)]
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
|
||||
@@ -7,7 +7,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class RankFusion:
|
||||
- 使用 Reciprocal Rank Fusion (RRF) 算法
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBDatabase, k: int = 60):
|
||||
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
|
||||
"""初始化结果融合器
|
||||
|
||||
Args:
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import List
|
||||
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,7 +30,7 @@ class SparseRetriever:
|
||||
- 使用 BM25 算法计算相关度
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBDatabase):
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
"""初始化稀疏检索器
|
||||
|
||||
Args:
|
||||
@@ -43,14 +43,14 @@ class SparseRetriever:
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int = 50,
|
||||
kb_options: dict,
|
||||
) -> List[SparseResult]:
|
||||
"""执行稀疏检索
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量
|
||||
kb_options: 每个知识库的检索选项
|
||||
|
||||
Returns:
|
||||
List[SparseResult]: 检索结果列表
|
||||
@@ -87,4 +87,4 @@ class SparseRetriever:
|
||||
)
|
||||
|
||||
results.sort(key=lambda x: x.score, reverse=True)
|
||||
return results[:top_k]
|
||||
return results[: len(results) // len(kb_ids)]
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
"""向量数据库工厂
|
||||
|
||||
负责为每个知识库创建和管理独立的向量数据库实例。
|
||||
|
||||
架构说明:
|
||||
- 每个知识库拥有独立的向量数据库实例
|
||||
- 向量数据库文件以 kb_id 命名
|
||||
- 工厂类负责实例的创建、缓存和生命周期管理
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
|
||||
|
||||
class VecDBFactory:
|
||||
"""向量数据库工厂
|
||||
|
||||
职责:
|
||||
- 为每个知识库创建独立的向量数据库实例
|
||||
- 缓存已创建的实例以提高性能
|
||||
- 管理向量数据库的生命周期
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_base_path: str,
|
||||
):
|
||||
"""初始化向量数据库工厂
|
||||
|
||||
Args:
|
||||
storage_base_path: 向量数据库存储基础路径
|
||||
"""
|
||||
self.storage_base_path = Path(storage_base_path)
|
||||
self._instances: Dict[str, BaseVecDB] = {}
|
||||
|
||||
# 确保基础路径存在
|
||||
self.storage_base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def get_vec_db(
|
||||
self, kb_id: str, embedding_provider: EmbeddingProvider
|
||||
) -> BaseVecDB:
|
||||
"""获取或创建指定知识库的向量数据库实例
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
embedding_provider: Embedding Provider 实例
|
||||
|
||||
Returns:
|
||||
BaseVecDB: 向量数据库实例
|
||||
"""
|
||||
# 如果已经创建过,直接返回缓存的实例
|
||||
if kb_id in self._instances:
|
||||
return self._instances[kb_id]
|
||||
|
||||
# 创建新实例
|
||||
vec_db = await self._create_vec_db(kb_id, embedding_provider)
|
||||
self._instances[kb_id] = vec_db
|
||||
|
||||
logger.debug(f"创建知识库 {kb_id} 的向量数据库实例")
|
||||
|
||||
return vec_db
|
||||
|
||||
async def _create_vec_db(
|
||||
self, kb_id: str, embedding_provider: EmbeddingProvider
|
||||
) -> BaseVecDB:
|
||||
"""创建向量数据库实例
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
embedding_provider: Embedding Provider 实例
|
||||
|
||||
Returns:
|
||||
BaseVecDB: 向量数据库实例
|
||||
"""
|
||||
# 为每个知识库创建独立的存储路径
|
||||
kb_storage_path = self.storage_base_path / kb_id
|
||||
kb_storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
doc_store_path = str(kb_storage_path / "documents.db")
|
||||
index_store_path = str(kb_storage_path / "index.faiss")
|
||||
|
||||
vec_db = FaissVecDB(
|
||||
doc_store_path=doc_store_path,
|
||||
index_store_path=index_store_path,
|
||||
embedding_provider=embedding_provider,
|
||||
)
|
||||
|
||||
await vec_db.initialize()
|
||||
|
||||
return vec_db
|
||||
|
||||
async def delete_vec_db(self, kb_id: str) -> bool:
|
||||
"""删除指定知识库的向量数据库
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
# 关闭并移除缓存的实例
|
||||
if kb_id in self._instances:
|
||||
try:
|
||||
await self._instances[kb_id].close()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")
|
||||
|
||||
del self._instances[kb_id]
|
||||
|
||||
# 删除文件系统中的向量数据库文件
|
||||
kb_storage_path = self.storage_base_path / kb_id
|
||||
if kb_storage_path.exists():
|
||||
try:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(kb_storage_path)
|
||||
logger.info(f"已删除知识库 {kb_id} 的向量数据库文件")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量数据库文件失败 ({kb_id}): {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def close_all(self):
|
||||
"""关闭所有向量数据库实例"""
|
||||
for kb_id, vec_db in list(self._instances.items()):
|
||||
try:
|
||||
await vec_db.close()
|
||||
logger.debug(f"已关闭知识库 {kb_id} 的向量数据库")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")
|
||||
|
||||
self._instances.clear()
|
||||
|
||||
def has_instance(self, kb_id: str) -> bool:
|
||||
"""检查是否已创建指定知识库的向量数据库实例
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
|
||||
Returns:
|
||||
bool: 是否已创建实例
|
||||
"""
|
||||
return kb_id in self._instances
|
||||
|
||||
def get_cached_instance(self, kb_id: str) -> Optional[BaseVecDB]:
|
||||
"""获取已缓存的向量数据库实例(不创建新实例)
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
|
||||
Returns:
|
||||
Optional[BaseVecDB]: 向量数据库实例,如果不存在则返回 None
|
||||
"""
|
||||
return self._instances.get(kb_id)
|
||||
@@ -15,13 +15,15 @@ async def inject_kb_context(
|
||||
p_ctx: Pipeline context
|
||||
req: Provider request
|
||||
"""
|
||||
kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector()
|
||||
if not kb_injector:
|
||||
kb_mgr = p_ctx.plugin_manager.context.kb_manager
|
||||
kb_names = p_ctx.astrbot_config.get("kb_names", [])
|
||||
|
||||
if not kb_names:
|
||||
return
|
||||
kb_context = await kb_injector.retrieve_and_inject(
|
||||
unified_msg_origin=umo,
|
||||
|
||||
kb_context = await kb_mgr.retrieve(
|
||||
query=req.prompt,
|
||||
top_k=top_k,
|
||||
kb_names=kb_names,
|
||||
)
|
||||
if not kb_context:
|
||||
return
|
||||
|
||||
@@ -19,7 +19,7 @@ from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,8 @@
|
||||
"delete": "Delete",
|
||||
"preview": "Preview",
|
||||
"search": "Search Chunks",
|
||||
"searchPlaceholder": "Enter keywords to search chunks..."
|
||||
"searchPlaceholder": "Enter keywords to search chunks...",
|
||||
"showing": "Showing"
|
||||
},
|
||||
"edit": {
|
||||
"title": "Edit Chunk",
|
||||
|
||||
@@ -21,7 +21,8 @@
|
||||
"delete": "删除",
|
||||
"preview": "预览",
|
||||
"search": "搜索分块",
|
||||
"searchPlaceholder": "输入关键词搜索分块内容..."
|
||||
"searchPlaceholder": "输入关键词搜索分块内容...",
|
||||
"showing": "显示"
|
||||
},
|
||||
"edit": {
|
||||
"title": "编辑分块",
|
||||
|
||||
@@ -82,7 +82,7 @@
|
||||
<v-card-title class="d-flex align-center pa-4">
|
||||
<span>{{ t('chunks.title') }}</span>
|
||||
<v-chip class="ml-2" size="small" variant="tonal">
|
||||
{{ chunks.length }} {{ t('chunks.title') }}
|
||||
{{ totalChunks }} {{ t('chunks.title') }}
|
||||
</v-chip>
|
||||
<v-spacer />
|
||||
<v-text-field
|
||||
@@ -104,7 +104,8 @@
|
||||
:headers="headers"
|
||||
:items="filteredChunks"
|
||||
:loading="loadingChunks"
|
||||
:items-per-page="10"
|
||||
:items-per-page="pageSize"
|
||||
hide-default-footer
|
||||
>
|
||||
<template #item.chunk_index="{ item }">
|
||||
<v-chip size="small" variant="tonal" color="primary">
|
||||
@@ -132,20 +133,6 @@
|
||||
color="info"
|
||||
@click="viewChunk(item)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-pencil"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="primary"
|
||||
@click="editChunk(item)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-delete"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="error"
|
||||
@click="confirmDeleteChunk(item)"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<template #no-data>
|
||||
@@ -155,6 +142,31 @@
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
|
||||
|
||||
<!-- 自定义分页器 -->
|
||||
<div v-if="!searchQuery && totalChunks > 0" class="pa-4 d-flex align-center justify-space-between">
|
||||
<div class="text-caption text-medium-emphasis">
|
||||
{{ t('chunks.showing') }} {{ (page - 1) * pageSize + 1 }} - {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }}
|
||||
</div>
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-select
|
||||
v-model="pageSize"
|
||||
:items="[10, 25, 50, 100]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
style="width: 100px"
|
||||
@update:model-value="handlePageSizeChange"
|
||||
/>
|
||||
<v-pagination
|
||||
v-model="page"
|
||||
:length="Math.ceil(totalChunks / pageSize)"
|
||||
:total-visible="5"
|
||||
@update:model-value="handlePageChange"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</div>
|
||||
@@ -212,71 +224,6 @@
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 编辑分块对话框 -->
|
||||
<v-dialog v-model="showEditDialog" max-width="800px" persistent scrollable>
|
||||
<v-card>
|
||||
<v-card-title class="pa-4">
|
||||
<span>{{ t('edit.title') }}</span>
|
||||
<v-spacer />
|
||||
<v-btn icon="mdi-close" variant="text" @click="closeEditDialog" />
|
||||
</v-card-title>
|
||||
<v-divider />
|
||||
<v-card-text class="pa-6">
|
||||
<v-textarea
|
||||
v-model="editForm.content"
|
||||
:label="t('edit.content')"
|
||||
variant="outlined"
|
||||
rows="15"
|
||||
auto-grow
|
||||
/>
|
||||
</v-card-text>
|
||||
<v-divider />
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="closeEditDialog">
|
||||
{{ t('edit.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="saveChunk"
|
||||
:loading="saving"
|
||||
>
|
||||
{{ t('edit.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 删除确认对话框 -->
|
||||
<v-dialog v-model="showDeleteDialog" max-width="450px">
|
||||
<v-card>
|
||||
<v-card-title class="pa-4 text-h6">{{ t('delete.title') }}</v-card-title>
|
||||
<v-divider />
|
||||
<v-card-text class="pa-6">
|
||||
<p>{{ t('delete.confirmText') }}</p>
|
||||
<v-alert type="warning" variant="tonal" density="compact" class="mt-4">
|
||||
{{ t('delete.warning') }}
|
||||
</v-alert>
|
||||
</v-card-text>
|
||||
<v-divider />
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="showDeleteDialog = false">
|
||||
{{ t('delete.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="error"
|
||||
variant="elevated"
|
||||
@click="deleteChunk"
|
||||
:loading="deleting"
|
||||
>
|
||||
{{ t('delete.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar v-model="snackbar.show" :color="snackbar.color">
|
||||
{{ snackbar.text }}
|
||||
@@ -299,16 +246,16 @@ const docId = ref(route.params.docId as string)
|
||||
// 状态
|
||||
const loading = ref(true)
|
||||
const loadingChunks = ref(false)
|
||||
const saving = ref(false)
|
||||
const deleting = ref(false)
|
||||
const document = ref<any>({})
|
||||
const chunks = ref<any[]>([])
|
||||
const searchQuery = ref('')
|
||||
const showViewDialog = ref(false)
|
||||
const showEditDialog = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const selectedChunk = ref<any>(null)
|
||||
const deleteTarget = ref<any>(null)
|
||||
|
||||
// 分页状态
|
||||
const page = ref(1)
|
||||
const pageSize = ref(10)
|
||||
const totalChunks = ref(0)
|
||||
|
||||
const snackbar = ref({
|
||||
show: false,
|
||||
@@ -322,17 +269,12 @@ const showSnackbar = (text: string, color: string = 'success') => {
|
||||
snackbar.value.show = true
|
||||
}
|
||||
|
||||
// 编辑表单
|
||||
const editForm = ref({
|
||||
content: ''
|
||||
})
|
||||
|
||||
// 表格列
|
||||
const headers = [
|
||||
{ title: t('chunks.index'), key: 'chunk_index', width: 100 },
|
||||
{ title: t('chunks.content'), key: 'content', sortable: false },
|
||||
{ title: t('chunks.charCount'), key: 'char_count', width: 150 },
|
||||
{ title: t('chunks.actions'), key: 'actions', sortable: false, align: 'end', width: 150 }
|
||||
{ title: t('chunks.actions'), key: 'actions', sortable: false, width: 150 }
|
||||
]
|
||||
|
||||
// 过滤分块
|
||||
@@ -349,7 +291,7 @@ const loadDocument = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/kb/document/get', {
|
||||
params: { doc_id: docId.value }
|
||||
params: { doc_id: docId.value, kb_id: kbId.value }
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
document.value = response.data.data
|
||||
@@ -367,10 +309,16 @@ const loadChunks = async () => {
|
||||
loadingChunks.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/kb/chunk/list', {
|
||||
params: { doc_id: docId.value }
|
||||
params: {
|
||||
doc_id: docId.value,
|
||||
kb_id: kbId.value,
|
||||
page: page.value,
|
||||
page_size: pageSize.value
|
||||
}
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
chunks.value = response.data.data.items || []
|
||||
totalChunks.value = response.data.data.items.length
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load chunks:', error)
|
||||
@@ -380,84 +328,24 @@ const loadChunks = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// 处理分页变化
|
||||
const handlePageChange = (newPage: number) => {
|
||||
page.value = newPage
|
||||
loadChunks()
|
||||
}
|
||||
|
||||
const handlePageSizeChange = (newPageSize: number) => {
|
||||
pageSize.value = newPageSize
|
||||
page.value = 1
|
||||
loadChunks()
|
||||
}
|
||||
|
||||
// 查看分块
|
||||
const viewChunk = (chunk: any) => {
|
||||
selectedChunk.value = chunk
|
||||
showViewDialog.value = true
|
||||
}
|
||||
|
||||
// 编辑分块
|
||||
const editChunk = (chunk: any) => {
|
||||
selectedChunk.value = chunk
|
||||
editForm.value.content = chunk.content
|
||||
showEditDialog.value = true
|
||||
}
|
||||
|
||||
// 关闭编辑对话框
|
||||
const closeEditDialog = () => {
|
||||
showEditDialog.value = false
|
||||
selectedChunk.value = null
|
||||
editForm.value.content = ''
|
||||
}
|
||||
|
||||
// 保存分块
|
||||
const saveChunk = async () => {
|
||||
if (!selectedChunk.value) return
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
const response = await axios.post('/api/kb/chunk/update', {
|
||||
chunk_id: selectedChunk.value.chunk_id,
|
||||
content: editForm.value.content
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
showSnackbar(t('edit.saveSuccess'))
|
||||
closeEditDialog()
|
||||
await loadChunks()
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('edit.saveFailed'), 'error')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to save chunk:', error)
|
||||
showSnackbar(t('edit.saveFailed'), 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 确认删除分块
|
||||
const confirmDeleteChunk = (chunk: any) => {
|
||||
deleteTarget.value = chunk
|
||||
showDeleteDialog.value = true
|
||||
}
|
||||
|
||||
// 删除分块
|
||||
const deleteChunk = async () => {
|
||||
if (!deleteTarget.value) return
|
||||
|
||||
deleting.value = true
|
||||
try {
|
||||
const response = await axios.post('/api/kb/chunk/delete', {
|
||||
chunk_id: deleteTarget.value.chunk_id
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
showSnackbar(t('delete.deleteSuccess'))
|
||||
showDeleteDialog.value = false
|
||||
await loadChunks()
|
||||
await loadDocument()
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('delete.deleteFailed'), 'error')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to delete chunk:', error)
|
||||
showSnackbar(t('delete.deleteFailed'), 'error')
|
||||
} finally {
|
||||
deleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 工具函数
|
||||
const getFileIcon = (fileType: string) => {
|
||||
const type = fileType?.toLowerCase() || ''
|
||||
@@ -582,6 +470,10 @@ onMounted(() => {
|
||||
font-family: 'Consolas', 'Monaco', monospace;
|
||||
}
|
||||
|
||||
.gap-2 {
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
.document-detail-page {
|
||||
|
||||
@@ -51,7 +51,7 @@
|
||||
</v-tabs>
|
||||
|
||||
<!-- 标签页内容 -->
|
||||
<v-window v-model="activeTab">
|
||||
<v-window v-model="activeTab" style="padding: 8px;">
|
||||
<!-- 概览 -->
|
||||
<v-window-item value="overview">
|
||||
<v-row>
|
||||
@@ -163,7 +163,7 @@
|
||||
|
||||
<!-- 知识库检索 -->
|
||||
<v-window-item value="retrieval">
|
||||
<RetrievalTab :kb-id="kbId" />
|
||||
<RetrievalTab :kb-id="kbId" :kb-name="kb.kb_name"/>
|
||||
</v-window-item>
|
||||
|
||||
<!-- 使用会话 -->
|
||||
@@ -329,12 +329,11 @@ onMounted(() => {
|
||||
padding: 24px;
|
||||
text-align: center;
|
||||
border-radius: 12px;
|
||||
background: rgba(var(--v-theme-surface-variant), 0.3);
|
||||
background: rgba(var(--v-theme-surface-variant), 0.1);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.stat-box:hover {
|
||||
transform: translateY(-4px);
|
||||
background: rgba(var(--v-theme-surface-variant), 0.5);
|
||||
}
|
||||
|
||||
@@ -342,12 +341,10 @@ onMounted(() => {
|
||||
font-size: 2rem;
|
||||
font-weight: 600;
|
||||
margin-top: 8px;
|
||||
color: rgb(var(--v-theme-on-surface));
|
||||
}
|
||||
|
||||
.stat-label {
|
||||
font-size: 0.875rem;
|
||||
color: rgb(var(--v-theme-on-surface-variant));
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,10 +7,16 @@
|
||||
</v-card-subtitle>
|
||||
|
||||
<v-divider />
|
||||
<v-progress-linear
|
||||
v-if="loading"
|
||||
indeterminate
|
||||
color="primary"
|
||||
height="2"
|
||||
/>
|
||||
|
||||
<v-card-text class="pa-6">
|
||||
<!-- 查询输入区域 -->
|
||||
<v-row>
|
||||
<v-row class="mb-4">
|
||||
<v-col cols="12" md="8">
|
||||
<v-textarea
|
||||
v-model="query"
|
||||
@@ -34,20 +40,7 @@
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
persistent-hint
|
||||
class="mb-3"
|
||||
/>
|
||||
|
||||
<v-checkbox
|
||||
v-model="enableRerank"
|
||||
:label="t('retrieval.enableRerank')"
|
||||
:hint="t('retrieval.enableRerankHint')"
|
||||
color="primary"
|
||||
density="compact"
|
||||
persistent-hint
|
||||
/>
|
||||
<v-alert v-if="enableRerank" type="info" variant="tonal" class="mt-2" density="compact">
|
||||
如果没有配置重排序模型提供商,将跳过重排序步骤
|
||||
</v-alert>
|
||||
</v-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -143,7 +136,8 @@ import { useModuleI18n } from '@/i18n/composables'
|
||||
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
|
||||
|
||||
const props = defineProps<{
|
||||
kbId: string
|
||||
kbId: string,
|
||||
kbName: string,
|
||||
}>()
|
||||
|
||||
// 状态
|
||||
@@ -179,7 +173,7 @@ const performRetrieval = async () => {
|
||||
try {
|
||||
const response = await axios.post('/api/kb/retrieve', {
|
||||
query: query.value,
|
||||
kb_ids: [props.kbId],
|
||||
kb_names: [props.kbName],
|
||||
top_k: topK.value,
|
||||
enable_rerank: enableRerank.value
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user