This commit is contained in:
Soulter
2025-10-23 21:20:41 +08:00
parent e3aa1315ae
commit 59df244173
26 changed files with 1141 additions and 2664 deletions
+2 -20
View File
@@ -134,27 +134,9 @@ DEFAULT_CONFIG = {
"persona": [], # deprecated
"timezone": "Asia/Shanghai",
"callback_api_base": "",
"default_kb_collection": "", # 默认知识库名称
"default_kb_collection": "", # 默认知识库名称, 已经过时
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
"knowledge_base": {
"enabled": False, # 默认禁用,用户需要主动启用
"embedding_provider_id": "", # 嵌入模型提供商 ID (为空时自动选择第一个)
"rerank_provider_id": "", # 重排序模型提供商 ID (为空时自动选择第一个)
"storage": {
"files_path": "data/knowledge_base", # 文件存储路径
"vector_db_path": "data/knowledge_base/vectors", # 向量数据库路径
},
"chunking": {
"chunk_size": 512, # 文档块大小(字符数)
"chunk_overlap": 50, # 文档块重叠大小(字符数)
},
"retrieval": {
"top_k_dense": 50, # 密集检索返回结果数
"top_k_sparse": 50, # 稀疏检索返回结果数
"top_m_final": 5, # 最终融合后返回的结果数
"enable_rerank": True, # 是否启用重排序
},
},
"kb_names": [], # 默认知识库名称列表
}
+2 -8
View File
@@ -34,7 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
class AstrBotCoreLifecycle:
@@ -112,9 +112,7 @@ class AstrBotCoreLifecycle:
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.provider_manager
)
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
# 初始化提供给插件的上下文
self.star_context = Context(
@@ -141,10 +139,6 @@ class AstrBotCoreLifecycle:
await self.kb_manager.initialize()
# 注册知识库会话生命周期钩子(零侵入级联清理)
if self.kb_manager.is_initialized:
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
-34
View File
@@ -1,34 +0,0 @@
"""
知识库管理模块
提供文档上传、解析、分块、向量化、检索等功能
"""
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
# 注意: 以下导入在对应模块实现后取消注释
from .database import KBDatabase
from .manager import KBManager
from .manager_ops import KBManagerOps
from .session_config_db import SessionConfigDB
# from .injector import KnowledgeBaseInjector
__all__ = [
"KnowledgeBase",
"KBDocument",
"KBChunk",
"KBMedia",
"KBSessionConfig",
"KBDatabase",
"SessionConfigDB",
"KBManager",
"KBManagerOps",
# "KnowledgeBaseInjector",
]
-183
View File
@@ -1,183 +0,0 @@
"""知识库数据库操作类
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
注意:
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
- 会话配置也存储在此数据库中,会话ID来源于主数据库
"""
from typing import Optional
from sqlalchemy import func, select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KnowledgeBase,
)
class KBDatabase:
"""知识库数据库操作类
职责:
- 封装知识库、文档、块、多媒体和会话配置的数据库查询操作
- 统一异常处理
注意:
- 该类操作独立的知识库数据库 (kb.db)
- 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库
"""
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化知识库数据库操作类
Args:
kb_db: 知识库独立数据库实例,而非主数据库
"""
self.db = kb_db
# ===== 知识库查询 =====
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
"""根据 ID 获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
"""根据名称获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.db.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(KnowledgeBase.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.db.get_db() as session:
stmt = select(func.count(KnowledgeBase.id))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
"""根据 ID 获取文档"""
async with self.db.get_db() as session:
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.db.get_db() as session:
stmt = (
select(KBDocument)
.where(KBDocument.kb_id == kb_id)
.offset(offset)
.limit(limit)
.order_by(KBDocument.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.db.get_db() as session:
stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 块查询 =====
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
"""根据 ID 获取块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
"""根据知识库 ID 列表获取所有块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
"""根据向量文档 ID 获取块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
"""获取块及其关联的文档和知识库元数据"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk, KBDocument, KnowledgeBase)
.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
.join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
.where(KBChunk.chunk_id == chunk_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
chunk, doc, kb = row
return {
"chunk": chunk,
"document": doc,
"knowledge_base": kb,
}
async def list_chunks_by_doc(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[KBChunk]:
"""列出文档的所有块"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk)
.where(KBChunk.doc_id == doc_id)
.offset(offset)
.limit(limit)
.order_by(KBChunk.chunk_index)
)
result = await session.execute(stmt)
return list(result.scalars().all())
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
"""根据 ID 获取多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
-112
View File
@@ -1,112 +0,0 @@
"""知识库上下文注入器
负责检索相关知识并格式化为 LLM 可用的上下文文本
"""
from typing import List, Optional
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.retrieval.manager import (
RetrievalManager,
RetrievalResult,
)
from .vec_db_factory import VecDBFactory
class KnowledgeBaseInjector:
"""知识库上下文注入器
职责:
- 检索相关知识
- 格式化为上下文文本
- 注入到 LLM Prompt
"""
def __init__(
self,
kb_db: KBDatabase,
vec_db_factory: VecDBFactory,
retrieval_manager: RetrievalManager,
):
"""初始化知识库上下文注入器
Args:
kb_db: 知识库数据库实例
retrieval_manager: 检索管理器实例
"""
self.kb_db = kb_db
self.vec_db_factory = vec_db_factory
self.retrieval_manager = retrieval_manager
async def retrieve_and_inject(
self,
kb_ids: list[str],
query: str,
top_k: int = 5,
) -> Optional[dict]:
"""检索并注入知识库上下文
Args:
query: 用户查询
top_k: 返回结果数量
Returns:
Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None
{
"context_text": str, # 格式化的上下文文本
"results": List[dict], # 原始检索结果列表
}
"""
# 2. 检索知识
results = await self.retrieval_manager.retrieve(
vec_db_factory=self.vec_db_factory,
query=query,
kb_ids=kb_ids,
top_m_final=top_k,
)
if not results:
return None
# 3. 格式化上下文
context_text = self._format_context(results)
# 4. 转换结果为字典格式
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: List[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
+383
View File
@@ -0,0 +1,383 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlmodel import SQLModel, col, desc
from sqlalchemy import text, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KnowledgeBase,
)
from typing import Optional
class KBSQLiteDatabase:
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
"""初始化知识库数据库
Args:
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
# 确保目录存在
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# 创建异步引擎
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_recycle=3600,
)
# 创建会话工厂
self.async_session = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_db(self):
"""获取数据库会话
用法:
async with kb_db.get_db() as session:
# 执行数据库操作
result = await session.execute(stmt)
"""
async with self.async_session() as session:
yield session
async def initialize(self) -> None:
"""初始化数据库,创建表并配置 SQLite 参数"""
async with self.engine.begin() as conn:
# 创建所有知识库相关表
await conn.run_sync(SQLModel.metadata.create_all)
# 配置 SQLite 性能优化参数
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
self.inited = True
async def migrate_to_v1(self) -> None:
"""执行知识库数据库 v1 迁移
创建所有必要的索引以优化查询性能
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
# 创建知识库表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)"
)
)
# 创建文档表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)"
)
)
# 创建块表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id "
"ON kb_chunks(chunk_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_doc_id "
"ON kb_chunks(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id "
"ON kb_chunks(vec_doc_id)"
)
)
# 创建多媒体表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)"
)
)
# 创建会话配置表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_session_config_scope_id "
"ON kb_session_config(scope_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_session_config_scope "
"ON kb_session_config(scope)"
)
)
await session.commit()
async def close(self) -> None:
"""关闭数据库连接"""
await self.engine.dispose()
logger.info(f"知识库数据库已关闭: {self.db_path}")
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
"""根据 ID 获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
"""根据名称获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(desc(KnowledgeBase.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KnowledgeBase.id)))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
"""根据 ID 获取文档"""
async with self.get_db() as session:
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.get_db() as session:
stmt = (
select(KBDocument)
.where(col(KBDocument.kb_id) == kb_id)
.offset(offset)
.limit(limit)
.order_by(desc(KBDocument.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KBDocument.id))).where(
col(KBDocument.kb_id) == kb_id
)
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 块查询 =====
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
"""根据 ID 获取块"""
async with self.get_db() as session:
stmt = select(KBChunk).where(col(KBChunk.chunk_id) == chunk_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
"""根据知识库 ID 列表获取所有块"""
async with self.get_db() as session:
stmt = select(KBChunk).where(col(KBChunk.kb_id).in_(kb_ids))
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
"""根据向量文档 ID 获取块"""
async with self.get_db() as session:
stmt = select(KBChunk).where(col(KBChunk.vec_doc_id) == vec_doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunks_by_doc_id(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[KBChunk]:
"""根据文档 ID 获取所有块"""
async with self.get_db() as session:
stmt = (
select(KBChunk)
.where(col(KBChunk.doc_id) == doc_id)
.offset(offset)
.limit(limit)
.order_by(col(KBChunk.chunk_index))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
"""获取块及其关联的文档和知识库元数据"""
async with self.get_db() as session:
stmt = (
select(KBChunk, KBDocument, KnowledgeBase)
.join(KBDocument, col(KBChunk.doc_id) == col(KBDocument.doc_id))
.join(KnowledgeBase, col(KBChunk.kb_id) == col(KnowledgeBase.kb_id))
.where(col(KBChunk.chunk_id) == chunk_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
chunk, doc, kb = row
return {
"chunk": chunk,
"document": doc,
"knowledge_base": kb,
}
async def list_chunks_by_doc(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[KBChunk]:
"""列出文档的所有块"""
async with self.get_db() as session:
stmt = (
select(KBChunk)
.where(col(KBChunk.doc_id) == doc_id)
.offset(offset)
.limit(limit)
.order_by(col(KBChunk.chunk_index))
)
result = await session.execute(stmt)
return list(result.scalars().all())
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
"""根据 ID 获取多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_kb_stats(self, kb_id: str):
"""更新知识库统计信息"""
async with self.get_db() as session:
async with session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=select(func.count(col(KBChunk.id)))
.where(col(KBChunk.kb_id) == kb_id)
.scalar_subquery(),
)
)
await session.execute(update_stmt)
await session.commit()
+248
View File
@@ -0,0 +1,248 @@
import uuid
import aiofiles
from pathlib import Path
from .models import KnowledgeBase, KBDocument, KBChunk, KBMedia
from .kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.provider.manager import ProviderManager
from .parsers.base import BaseParser
from .chunking.base import BaseChunker
from astrbot.core import logger
class KBHelper:
vec_db: BaseVecDB
def __init__(
self,
kb_db: KBSQLiteDatabase,
kb: KnowledgeBase,
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
parsers: dict[str, BaseParser],
):
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
self.kb_root_dir = kb_root_dir
self.parsers = parsers
self.chunker = chunker
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id
) # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
)
return ep
async def get_rp(self) -> RerankProvider | None:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
)
return rp
async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep = await self.get_ep()
rp = await self.get_rp()
vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
index_store_path=str(self.kb_dir / "index.faiss"),
embedding_provider=ep,
rerank_provider=rp,
)
await vec_db.initialize()
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
await self.terminate()
if self.kb_dir.exists():
for item in self.kb_dir.iterdir():
if item.is_file():
item.unlink()
self.kb_dir.rmdir()
async def terminate(self):
if self.vec_db:
await self.vec_db.close()
async def upload_document(
self,
file_name: str,
file_content: bytes,
file_type: str,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源 (TODO)
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
vec_doc_ids = []
file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
async with aiofiles.open(file_path, "wb") as f:
await f.write(file_content)
try:
parser = self.parsers.get(file_type)
if not parser:
raise ValueError(f"不支持的文件类型: {file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
# 保存媒体文件
saved_media = []
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 分块并生成向量
saved_chunks = []
chunks_text = await self.chunker.chunk(text_content)
for idx, chunk_text in enumerate(chunks_text):
vec_doc_id = await self.vec_db.insert(
content=chunk_text,
metadata={
"kb_id": self.kb.kb_id,
"doc_id": doc_id,
"chunk_index": idx,
},
)
vec_doc_ids.append(str(vec_doc_id))
chunk = KBChunk(
doc_id=doc_id,
kb_id=self.kb.kb_id,
chunk_index=idx,
content=chunk_text,
char_count=len(chunk_text),
vec_doc_id=str(vec_doc_id),
)
saved_chunks.append(chunk)
# 保存文档和块的元数据
doc = KBDocument(
doc_id=doc_id,
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
file_path=str(file_path),
chunk_count=len(saved_chunks),
media_count=0,
)
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for chunk in saved_chunks:
session.add(chunk)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id)
return doc
except Exception as e:
logger.error(f"上传文档失败: {e}")
if file_path.exists():
file_path.unlink()
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
raise e
async def list_documents(
self, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
return docs
async def get_document(self, doc_id: str) -> KBDocument | None:
"""获取单个文档"""
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def _save_media(
self,
doc_id: str,
media_type: str,
file_name: str,
content: bytes,
mime_type: str,
) -> KBMedia:
"""保存多媒体资源"""
media_id = str(uuid.uuid4())
ext = Path(file_name).suffix
# 保存文件
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
media = KBMedia(
media_id=media_id,
doc_id=doc_id,
kb_id=self.kb.kb_id,
media_type=media_type,
file_name=file_name,
file_path=str(file_path),
file_size=len(content),
mime_type=mime_type,
)
return media
@@ -1,205 +0,0 @@
"""
知识库管理器
负责知识库模块的初始化、配置和资源管理
架构说明:
- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db)
- 会话配置存储在主数据库 (data/astrbot.db) 以便于会话关联
"""
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
from .injector import KnowledgeBaseInjector
from .retrieval.manager import RetrievalManager
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_sqlite import KBSQLiteDatabase
from .database import KBDatabase
from .vec_db_factory import VecDBFactory
from .manager import KBManager
from .parsers.text_parser import TextParser
from .parsers.pdf_parser import PDFParser
from .chunking.fixed_size import FixedSizeChunker
class KnowledgeBaseManager:
"""知识库管理器
职责:
- 知识库模块的初始化
- Embedding Provider 和 Rerank Provider 的选择
- 各个子组件的协调管理
- 注册会话删除回调,实现级联清理
架构说明:
- 知识库数据存储在独立数据库 (kb.db)
- 会话配置存储在独立数据库 (kb.db),会话ID来自主数据库
- 通过回调机制实现与主数据库的生命周期同步
"""
kb_db: KBSQLiteDatabase
vec_db_factory: VecDBFactory
kb_database: KBDatabase
kb_manager: KBManager
retrieval_manager: RetrievalManager
kb_injector: KnowledgeBaseInjector
def __init__(
self,
config: dict,
provider_manager: ProviderManager,
):
"""初始化知识库管理器
Args:
config: 配置字典
provider_manager: Provider 管理器
"""
self.config = config.get("knowledge_base", {})
self.provider_manager = provider_manager
self._initialized = False
self._session_deleted_callback_registered = False
async def initialize(self):
"""初始化知识库模块"""
if not self.config.get("enabled", False):
logger.info("知识库功能未启用")
return
try:
logger.info("正在初始化知识库模块...")
# 初始化数据库
await self._init_kb_database()
# 初始化向量数据库工厂
await self._init_vector_db_factory()
# 初始化解析器和分块器
parsers = {
"txt": TextParser(),
"md": TextParser(),
"markdown": TextParser(),
"pdf": PDFParser(),
}
chunking_config = self.config.get("chunking", {})
chunker = FixedSizeChunker(
chunk_size=chunking_config.get("chunk_size", 512),
chunk_overlap=chunking_config.get("chunk_overlap", 50),
)
# 初始化知识库管理器
files_path = self.config.get("storage", {}).get(
"files_path", "data/knowledge_base"
)
self.kb_manager = KBManager(
db=self.kb_db,
vec_db_factory=self.vec_db_factory,
storage_path=files_path,
parsers=parsers,
chunker=chunker,
provider_manager=self.provider_manager,
)
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_database)
rank_fusion = RankFusion(self.kb_database)
self.retrieval_manager = RetrievalManager(
vec_db_factory=self.vec_db_factory,
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_database,
)
# 初始化上下文注入器
self.kb_injector = KnowledgeBaseInjector(
kb_db=self.kb_database,
vec_db_factory=self.vec_db_factory,
retrieval_manager=self.retrieval_manager,
)
self._initialized = True
logger.info("知识库模块初始化完成")
except ImportError as e:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}")
import traceback
logger.error(traceback.format_exc())
async def _init_kb_database(self):
"""初始化知识库独立数据库"""
db_path = self.config.get("storage", {}).get(
"kb_db_path", "data/knowledge_base/kb.db"
)
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
self.kb_db = KBSQLiteDatabase(db_path)
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
self.kb_database = KBDatabase(self.kb_db)
logger.info(f"KnowledgeBase database initialized: {db_path}")
async def _init_vector_db_factory(self):
"""初始化向量数据库工厂"""
storage_path = self.config.get("storage", {}).get(
"vector_db_path", "data/knowledge_base/vectors"
)
Path(storage_path).mkdir(parents=True, exist_ok=True)
self.vec_db_factory = VecDBFactory(storage_base_path=storage_path)
@property
def is_initialized(self) -> bool:
"""检查是否已初始化"""
return self._initialized
def get_kb_manager(self):
"""获取知识库管理器"""
return self.kb_manager if self._initialized else None
def get_kb_injector(self):
"""获取知识库上下文注入器"""
return self.kb_injector if self._initialized else None
async def reinitialize(self):
"""重新初始化知识库模块
用于在运行时动态初始化知识库模块(例如用户添加了 embedding provider 后)
"""
if self._initialized:
logger.info("知识库模块已初始化,将重新初始化")
await self.terminate()
await self.initialize()
return self._initialized
async def terminate(self):
"""终止知识库模块,清理资源"""
if not self._initialized:
return
logger.info("正在终止知识库模块...")
# 关闭向量数据库工厂(关闭所有向量数据库实例)
if self.vec_db_factory:
try:
await self.vec_db_factory.close_all()
logger.debug("向量数据库工厂已关闭")
except Exception as e:
logger.warning(f"关闭向量数据库工厂时出错: {e}")
# 关闭知识库独立数据库连接
if self.kb_db:
try:
await self.kb_db.close()
logger.debug("知识库数据库已关闭")
except Exception as e:
logger.warning(f"关闭知识库数据库时出错: {e}")
self._initialized = False
logger.info("知识库模块已终止")
+275
View File
@@ -0,0 +1,275 @@
import traceback
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_db_sqlite import KBSQLiteDatabase
from .parsers.text_parser import TextParser
from .parsers.pdf_parser import PDFParser
from .chunking.fixed_size import FixedSizeChunker
from .kb_helper import KBHelper
from .models import KnowledgeBase
FILES_PATH = "data/knowledge_base"
DB_PATH = Path(FILES_PATH) / "kb.db"
"""Knowledge Base storage root directory"""
PARSERS = {
"txt": TextParser(),
"md": TextParser(),
"markdown": TextParser(),
"pdf": PDFParser(),
}
CHUNKER = FixedSizeChunker()
class KnowledgeBaseManager:
kb_db: KBSQLiteDatabase
retrieval_manager: RetrievalManager
def __init__(
self,
provider_manager: ProviderManager,
):
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
# 初始化数据库
await self._init_kb_database()
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_db)
rank_fusion = RankFusion(self.kb_db)
self.retrieval_manager = RetrievalManager(
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_db,
)
await self.load_kbs()
except ImportError as e:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=record,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
parsers=PARSERS,
)
await kb_helper.initialize()
self.kb_insts[record.kb_id] = kb_helper
async def create_kb(
self,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper:
"""创建新的知识库实例"""
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
)
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=kb,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
parsers=PARSERS,
)
await kb_helper.initialize()
self.kb_insts[kb.kb_id] = kb_helper
return kb_helper
async def get_kb(self, kb_id: str) -> KBHelper | None:
"""获取知识库实例"""
if kb_id in self.kb_insts:
return self.kb_insts[kb_id]
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
"""通过名称获取知识库实例"""
for kb_helper in self.kb_insts.values():
if kb_helper.kb.kb_name == kb_name:
return kb_helper
return None
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return False
await kb_helper.delete_vec_db()
async with self.kb_db.get_db() as session:
await session.delete(kb_helper.kb)
await session.commit()
self.kb_insts.pop(kb_id, None)
return True
async def list_kbs(self) -> list[KnowledgeBase]:
"""列出所有知识库实例"""
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
return kbs
async def update_kb(
self,
kb_id: str,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper | None:
"""更新知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return None
kb = kb_helper.kb
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
if rerank_provider_id is not None:
kb.rerank_provider_id = rerank_provider_id
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
async def retrieve(
self,
query: str,
kb_names: list[str],
top_m_final: int = 5,
) -> dict | None:
"""从指定知识库中检索相关内容"""
kb_ids = []
kb_id_helper_map = {}
for kb_name in kb_names:
if kb_helper := await self.get_kb_by_name(kb_name):
kb_ids.append(kb_helper.kb.kb_id)
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
if not kb_ids:
return {}
results = await self.retrieval_manager.retrieve(
query=query,
kb_ids=kb_ids,
kb_id_helper_map=kb_id_helper_map,
top_m_final=top_m_final,
)
if not results:
return None
context_text = self._format_context(results)
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: list[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
-230
View File
@@ -1,230 +0,0 @@
"""
知识库独立 SQLite 数据库
该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。
职责:
- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media)
- 提供数据库连接和会话管理
- 执行数据库迁移和初始化
"""
from contextlib import asynccontextmanager
from pathlib import Path
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
class KBSQLiteDatabase:
"""知识库独立 SQLite 数据库
与主数据库 (astrbot.db) 完全隔离的独立数据库,专门用于存储知识库数据。
特点:
- 数据隔离: 知识库数据不会影响主数据库格式
- 独立备份: 可以单独备份和恢复知识库数据
- 性能隔离: 大量知识库查询不会影响主业务性能
"""
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
"""初始化知识库数据库
Args:
db_path: 数据库文件路径,默认为 data/knowledge_base/kb.db
"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
# 确保目录存在
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# 创建异步引擎
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_recycle=3600,
)
# 创建会话工厂
self.async_session = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_db(self):
"""获取数据库会话
用法:
async with kb_db.get_db() as session:
# 执行数据库操作
result = await session.execute(stmt)
"""
async with self.async_session() as session:
yield session
async def initialize(self) -> None:
"""初始化数据库,创建表并配置 SQLite 参数"""
# noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表
from astrbot.core.knowledge_base.models import ( # noqa: F401
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
from sqlmodel import SQLModel
async with self.engine.begin() as conn:
# 创建所有知识库相关表
await conn.run_sync(SQLModel.metadata.create_all)
# 配置 SQLite 性能优化参数
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
self.inited = True
logger.info(f"知识库数据库已初始化: {self.db_path}")
async def migrate_to_v1(self) -> None:
"""执行知识库数据库 v1 迁移
创建所有必要的索引以优化查询性能
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
# 创建知识库表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)"
)
)
# 创建文档表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)"
)
)
# 创建块表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id "
"ON kb_chunks(chunk_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_doc_id "
"ON kb_chunks(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id "
"ON kb_chunks(vec_doc_id)"
)
)
# 创建多媒体表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)"
)
)
# 创建会话配置表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_session_config_scope_id "
"ON kb_session_config(scope_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_session_config_scope "
"ON kb_session_config(scope)"
)
)
await session.commit()
logger.info("知识库数据库迁移 v1 完成")
async def close(self) -> None:
"""关闭数据库连接"""
await self.engine.dispose()
logger.info(f"知识库数据库已关闭: {self.db_path}")
-430
View File
@@ -1,430 +0,0 @@
"""知识库管理器
该模块提供知识库的CRUD操作和文档上传处理流程。
"""
import uuid
from pathlib import Path
from typing import Optional
import aiofiles
from sqlalchemy import func, select, update
from .kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.chunking.base import BaseChunker
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
from astrbot.core.knowledge_base.parsers.base import BaseParser
from .vec_db_factory import VecDBFactory
class KBManager:
"""知识库管理器
职责:
- 知识库的 CRUD 操作
- 文档上传与解析
- 文档块生成与存储
- 多媒体资源管理
"""
def __init__(
self,
db: KBSQLiteDatabase,
vec_db_factory: VecDBFactory,
storage_path: str,
parsers: dict[str, BaseParser],
chunker: BaseChunker,
provider_manager=None,
):
self.db = db
self.vec_db_factory = vec_db_factory
self.storage_path = Path(storage_path)
self.media_path = self.storage_path / "media"
self.files_path = self.storage_path / "files"
self.parsers = parsers
self.chunker = chunker
self.provider_manager = provider_manager
# 确保目录存在
self.media_path.mkdir(parents=True, exist_ok=True)
self.files_path.mkdir(parents=True, exist_ok=True)
async def _get_embedding_provider_for_kb(self, kb_id: str):
"""根据知识库配置获取 Embedding Provider
Args:
kb_id: 知识库 ID
Returns:
EmbeddingProvider: Embedding Provider 实例
Raises:
ValueError: 如果找不到合适的 embedding provider
"""
from astrbot.core.knowledge_base.database import KBDatabase
# 获取知识库配置
kb_database = KBDatabase(self.db)
kb = await kb_database.get_kb_by_id(kb_id)
if not kb:
raise ValueError(f"知识库不存在: {kb_id}")
embedding_provider_id = kb.embedding_provider_id
# 如果没有 provider_manager,使用默认的第一个
if not self.provider_manager:
raise ValueError("Provider Manager 未初始化")
embedding_providers = self.provider_manager.embedding_provider_insts
if not embedding_providers:
raise ValueError("系统中没有可用的 Embedding Provider")
# 如果指定了 provider ID,则查找该 provider
if embedding_provider_id:
for provider in embedding_providers:
if provider.meta().id == embedding_provider_id:
return provider
raise ValueError(
f"未找到配置的 Embedding Provider: {embedding_provider_id}"
)
# 使用第一个可用的 provider
return embedding_providers[0]
# ===== 知识库操作 =====
async def create_kb(
self,
kb_name: str,
description: Optional[str] = None,
emoji: Optional[str] = None,
embedding_provider_id: Optional[str] = None,
rerank_provider_id: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
top_k_dense: Optional[int] = None,
top_k_sparse: Optional[int] = None,
top_m_final: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> KnowledgeBase:
"""创建知识库
Args:
enable_rerank: 是否启用重排序。
- 如果明确传入 True/False,则使用该值
- 如果为 None,则根据是否有可用的 rerank provider 自动决定
"""
# 智能决定 enable_rerank 的默认值
if enable_rerank is None:
# 检查是否有可用的 rerank provider
has_rerank_provider = (
self.provider_manager
and hasattr(self.provider_manager, "rerank_provider_insts")
and len(self.provider_manager.rerank_provider_insts) > 0
)
enable_rerank = has_rerank_provider
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
enable_rerank=enable_rerank,
)
async with self.db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
return kb
async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]:
"""获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.db.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(KnowledgeBase.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def update_kb(
self,
kb_id: str,
kb_name: Optional[str] = None,
description: Optional[str] = None,
emoji: Optional[str] = None,
embedding_provider_id: Optional[str] = None,
rerank_provider_id: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
top_k_dense: Optional[int] = None,
top_k_sparse: Optional[int] = None,
top_m_final: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> Optional[KnowledgeBase]:
"""更新知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
kb = result.scalar_one_or_none()
if not kb:
return None
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
if rerank_provider_id is not None:
kb.rerank_provider_id = rerank_provider_id
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
if enable_rerank is not None:
kb.enable_rerank = enable_rerank
await session.commit()
await session.refresh(kb)
return kb
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库(级联删除所有文档和资源)"""
# 1. 获取所有文档
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
ops = KBManagerOps(self)
docs = await ops.list_documents(kb_id)
# 2. 删除所有文档(包括文件和向量)
for doc in docs:
await ops.delete_document(doc.doc_id)
# 3. 删除向量数据库
await self.vec_db_factory.delete_vec_db(kb_id)
# 4. 删除知识库记录
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
kb = result.scalar_one_or_none()
if not kb:
return False
await session.delete(kb)
await session.commit()
return True
# ===== 文档上传 =====
async def upload_document(
self,
kb_id: str,
file_name: str,
file_content: bytes,
file_type: str,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
"""
doc_id = str(uuid.uuid4())
file_path = None
media_paths = []
vec_doc_ids = []
try:
# 1. 保存原始文件
file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(file_content)
# 2. 解析文档
parser = self.parsers.get(file_type)
if not parser:
raise ValueError(f"不支持的文件类型: {file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
# 3. 保存多媒体资源
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
ops = KBManagerOps(self)
saved_media = []
for media_item in media_items:
media = await ops._save_media(
kb_id=kb_id,
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 4. 文档分块
chunks_text = await self.chunker.chunk(text_content)
# 5. 获取 Embedding Provider 和向量数据库
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
# 6. 生成向量并存储
saved_chunks = []
for idx, chunk_text in enumerate(chunks_text):
# 存储到向量数据库
vec_doc_id = await vec_db.insert(
content=chunk_text,
metadata={
"kb_id": kb_id,
"doc_id": doc_id,
"chunk_index": idx,
},
)
vec_doc_ids.append(str(vec_doc_id))
# 保存块元数据
chunk = KBChunk(
doc_id=doc_id,
kb_id=kb_id,
chunk_index=idx,
content=chunk_text,
char_count=len(chunk_text),
vec_doc_id=str(vec_doc_id),
)
saved_chunks.append(chunk)
# 7. 保存文档元数据(事务)
doc = KBDocument(
doc_id=doc_id,
kb_id=kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
file_path=str(file_path),
chunk_count=len(saved_chunks),
media_count=len(saved_media),
)
async with self.db.get_db() as session:
async with session.begin():
session.add(doc)
for chunk in saved_chunks:
session.add(chunk)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
# 8. 更新知识库统计
await self._update_kb_stats(kb_id)
return doc
except Exception as e:
# 失败清理:删除已创建的资源
from astrbot.core import logger
logger.error(f"文档上传失败,开始清理资源: {e}")
# 获取知识库的向量数据库
try:
embedding_provider = await self._get_embedding_provider_for_kb(kb_id)
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
# 清理向量数据库
for vec_id in vec_doc_ids:
try:
await vec_db.delete(vec_id)
except Exception as ve:
logger.warning(f"清理向量失败 {vec_id}: {ve}")
except Exception as vfe:
logger.error(f"获取向量数据库失败: {vfe}")
# 清理多媒体文件
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
# 清理文档文件
if file_path and file_path.exists():
try:
file_path.unlink()
except Exception as fe:
logger.warning(f"清理文档文件失败 {file_path}: {fe}")
# 重新抛出原始异常
raise
# ===== 统计更新 =====
async def _update_kb_stats(self, kb_id: str):
"""更新知识库统计信息(事务中执行)"""
async with self.db.get_db() as session:
async with session.begin():
# 统计文档数(在事务中查询)
doc_count = (
await session.scalar(
select(func.count(KBDocument.id)).where(
KBDocument.kb_id == kb_id
)
)
or 0
)
# 统计块数(在事务中查询)
chunk_count = (
await session.scalar(
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
)
or 0
)
# 更新知识库(在同一事务中)
await session.execute(
update(KnowledgeBase)
.where(KnowledgeBase.kb_id == kb_id)
.values(doc_count=doc_count, chunk_count=chunk_count)
)
await session.commit()
-323
View File
@@ -1,323 +0,0 @@
"""知识库管理器辅助操作
该模块提供文档、块和多媒体的管理操作。
"""
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import aiofiles
from sqlalchemy import delete, func, select
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
if TYPE_CHECKING:
from astrbot.core.knowledge_base.manager import KBManager
class KBManagerOps:
"""知识库管理器辅助操作类
职责:
- 文档管理操作
- 块管理操作
- 多媒体管理操作
"""
def __init__(self, manager: "KBManager"):
self.manager = manager
self.db = manager.db
self.vec_db_factory = manager.vec_db_factory
self.media_path = manager.media_path
self.files_path = manager.files_path
# ===== 文档操作 =====
async def list_documents(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.db.get_db() as session:
stmt = (
select(KBDocument)
.where(KBDocument.kb_id == kb_id)
.offset(offset)
.limit(limit)
.order_by(KBDocument.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_document(self, doc_id: str) -> KBDocument | None:
"""获取文档详情"""
async with self.db.get_db() as session:
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def delete_document(self, doc_id: str) -> bool:
"""删除文档(级联删除块、多媒体、向量)
采用三阶段删除策略:
1. 删除向量数据库中的向量(允许部分失败)
2. 删除SQL数据库中的记录(事务保证原子性)
3. 删除文件系统中的文件(失败不影响数据一致性)
"""
from astrbot.core import logger
# 0. 获取文档信息
doc = await self.get_document(doc_id)
if not doc:
return False
# 收集所有需要删除的资源
chunks = await self.list_chunks(doc_id)
media_list = await self.list_media(doc_id)
# 获取知识库的向量数据库
embedding_provider = await self.manager._get_embedding_provider_for_kb(
doc.kb_id
)
vec_db = await self.vec_db_factory.get_vec_db(doc.kb_id, embedding_provider)
# ===== 第一阶段: 删除向量(可重试) =====
vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
deleted_vec_ids = []
failed_vec_ids = []
for vec_id in vec_ids_to_delete:
try:
await vec_db.delete(vec_id)
deleted_vec_ids.append(vec_id)
except Exception as e:
logger.error(f"删除向量失败: {vec_id}, {e}")
failed_vec_ids.append(vec_id)
# 如果向量删除失败过多(超过50%),中止操作
if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
logger.error(
f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
)
return False
# 记录部分失败但继续执行
if failed_vec_ids:
logger.warning(
f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
)
# ===== 第二阶段: 删除数据库记录(事务) =====
async with self.db.get_db() as session:
async with session.begin():
# 删除块记录
await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
# 删除多媒体记录
await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
# 删除文档记录
await session.execute(
delete(KBDocument).where(KBDocument.doc_id == doc_id)
)
await session.commit()
# ===== 第三阶段: 删除文件(失败不影响) =====
# 删除多媒体文件
for media in media_list:
try:
media_path = Path(media.file_path)
if media_path.exists():
media_path.unlink()
except Exception as e:
logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")
# 删除文档文件
try:
file_path = Path(doc.file_path)
if file_path.exists():
file_path.unlink()
except Exception as e:
logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")
# ===== 更新统计 =====
await self.manager._update_kb_stats(doc.kb_id)
return True
# ===== 块操作 =====
async def list_chunks(self, doc_id: str) -> list[KBChunk]:
"""列出文档的所有块"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk)
.where(KBChunk.doc_id == doc_id)
.order_by(KBChunk.chunk_index)
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def delete_chunk(self, chunk_id: str) -> bool:
"""删除单个块
流程:
1. 查询块信息
2. 删除向量
3. 删除数据库记录
4. 更新文档统计
"""
from astrbot.core import logger
# 1. 查询块信息
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
result = await session.execute(stmt)
chunk = result.scalar_one_or_none()
if not chunk:
return False
doc_id = chunk.doc_id
kb_id = chunk.kb_id
vec_doc_id = chunk.vec_doc_id
# 2. 获取知识库的向量数据库并删除向量
try:
embedding_provider = await self.manager._get_embedding_provider_for_kb(
kb_id
)
vec_db = await self.vec_db_factory.get_vec_db(kb_id, embedding_provider)
await vec_db.delete(vec_doc_id)
except Exception as e:
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
return False
# 3. 删除数据库记录
async with self.db.get_db() as session:
async with session.begin():
await session.execute(
delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
)
await session.commit()
# 4. 更新文档统计
await self._update_doc_stats(doc_id)
return True
# ===== 多媒体操作 =====
async def list_media(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def delete_media(self, media_id: str) -> bool:
"""删除多媒体资源
流程:
1. 查询媒体信息
2. 删除数据库记录
3. 删除文件(失败不影响)
4. 更新文档统计
"""
from astrbot.core import logger
# 1. 查询媒体信息
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
result = await session.execute(stmt)
media = result.scalar_one_or_none()
if not media:
return False
doc_id = media.doc_id
file_path_str = media.file_path
# 2. 删除数据库记录
async with self.db.get_db() as session:
async with session.begin():
await session.execute(
delete(KBMedia).where(KBMedia.media_id == media_id)
)
await session.commit()
# 3. 删除文件(失败不影响)
try:
media_path = Path(file_path_str)
if media_path.exists():
media_path.unlink()
except Exception as e:
logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}")
# 4. 更新文档统计
await self._update_doc_stats(doc_id)
return True
# ===== 内部辅助方法 =====
async def _save_media(
self,
kb_id: str,
doc_id: str,
media_type: str,
file_name: str,
content: bytes,
mime_type: str,
) -> KBMedia:
"""保存多媒体资源"""
media_id = str(uuid.uuid4())
ext = Path(file_name).suffix
# 保存文件
file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
# 创建记录
media = KBMedia(
media_id=media_id,
doc_id=doc_id,
kb_id=kb_id,
media_type=media_type,
file_name=file_name,
file_path=str(file_path),
file_size=len(content),
mime_type=mime_type,
)
return media
async def _update_doc_stats(self, doc_id: str):
"""更新文档统计信息(事务中执行)"""
async with self.db.get_db() as session:
async with session.begin():
# 统计块数
chunk_count = (
await session.scalar(
select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
)
) or 0
# 统计多媒体数
media_count = (
await session.scalar(
select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
)
) or 0
# 更新文档
doc = await session.scalar(
select(KBDocument).where(KBDocument.doc_id == doc_id)
)
if doc:
doc.chunk_count = chunk_count
doc.media_count = media_count
await session.commit()
+8 -16
View File
@@ -1,22 +1,8 @@
"""知识库管理功能的数据模型定义
该模块定义了知识库系统所需的数据模型,包括:
- KnowledgeBase: 知识库表 (存储在独立的 kb.db)
- KBDocument: 文档表 (存储在独立的 kb.db)
- KBChunk: 文档块表 (存储在独立的 kb.db)
- KBMedia: 多媒体资源表 (存储在独立的 kb.db)
- KBSessionConfig: 会话配置表 (存储在独立的 kb.db)
注意:
- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db)
- 与主数据库 (astrbot.db) 完全解耦
"""
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import Field, SQLModel, Text
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
class KnowledgeBase(SQLModel, table=True):
@@ -49,7 +35,6 @@ class KnowledgeBase(SQLModel, table=True):
top_k_dense: Optional[int] = Field(default=50, nullable=True)
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
top_m_final: Optional[int] = Field(default=5, nullable=True)
enable_rerank: Optional[bool] = Field(default=False, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
@@ -58,6 +43,13 @@ class KnowledgeBase(SQLModel, table=True):
doc_count: int = Field(default=0, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"kb_name",
name="uix_kb_name",
),
)
class KBDocument(SQLModel, table=True):
"""文档表
@@ -51,10 +51,10 @@ class PDFParser(BaseParser):
continue
resources = page["/Resources"]
if not resources or "/XObject" not in resources:
if not resources or "/XObject" not in resources: # type: ignore
continue
xobjects = resources["/XObject"].get_object()
xobjects = resources["/XObject"].get_object() # type: ignore
if not xobjects:
continue
@@ -6,12 +6,15 @@
from dataclasses import dataclass
from typing import List
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.db.vec_db.base import BaseVecDB, Result
from ..vec_db_factory import VecDBFactory
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from ..kb_helper import KBHelper
from astrbot import logger
@dataclass
class RetrievalResult:
@@ -37,10 +40,9 @@ class RetrievalManager:
def __init__(
self,
vec_db_factory, # VecDBFactory
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBDatabase,
kb_db: KBSQLiteDatabase,
):
"""初始化检索管理器
@@ -50,21 +52,16 @@ class RetrievalManager:
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
"""
self.vec_db_factory = vec_db_factory
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
async def retrieve(
self,
vec_db_factory: VecDBFactory,
query: str,
kb_ids: List[str],
top_k_dense: int = 50,
top_k_sparse: int = 50,
top_n_fusion: int = 20,
kb_id_helper_map: dict[str, KBHelper],
top_m_final: int = 5,
rerank_provider: RerankProvider | None = None,
) -> List[RetrievalResult]:
"""混合检索
@@ -77,35 +74,52 @@ class RetrievalManager:
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k_dense: 稠密检索返回数量
top_k_sparse: 稀疏检索返回数量
top_n_fusion: 融合后返回数量
top_m_final: 最终返回数量
enable_rerank: 是否启用 Rerank
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
kb_options: dict = {}
new_kb_ids = []
for kb_id in kb_ids:
kb_helper = kb_id_helper_map.get(kb_id)
if kb_helper:
kb = kb_helper.kb
kb_options[kb_id] = {
"top_k_dense": kb.top_k_dense or 50,
"top_k_sparse": kb.top_k_sparse or 50,
"top_m_final": kb.top_m_final or 5,
"vec_db": kb_helper.vec_db,
}
new_kb_ids.append(kb_id)
else:
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
kb_ids = new_kb_ids
# 1. 稠密检索
dense_results = await self._dense_retrieve(
query=query,
kb_ids=kb_ids,
top_k=top_k_dense,
vec_db=vec_db,
kb_options=kb_options,
)
# 2. 稀疏检索
sparse_results = await self.sparse_retriever.retrieve(
query=query,
kb_ids=kb_ids,
top_k=top_k_sparse,
kb_options=kb_options,
)
# 3. 结果融合
fused_results = await self.rank_fusion.fuse(
dense_results=dense_results,
sparse_results=sparse_results,
top_k=top_n_fusion,
top_k=kb_options.get("top_k_fusion", 20),
)
# 4. 转换为 RetrievalResult (获取元数据)
@@ -130,24 +144,27 @@ class RetrievalManager:
)
# 5. Rerank
if rerank_provider and retrieval_results:
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
if vec_db and vec_db.rerank_provider:
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=rerank_provider,
rerank_provider=first_rerank,
)
else:
retrieval_results = retrieval_results[:top_m_final]
return retrieval_results
return retrieval_results[:top_m_final]
async def _dense_retrieve(
self,
query: str,
kb_ids: List[str],
top_k: int,
vec_db: BaseVecDB,
kb_options: dict,
):
"""稠密检索 (向量相似度)
@@ -162,13 +179,17 @@ class RetrievalManager:
List[Result]: 检索结果列表
"""
all_results: list[Result] = []
for kb_id in kb_ids:
if kb_id not in kb_options:
continue
try:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
dense_k = int(kb_options[kb_id]["top_k_dense"])
vec_results = await vec_db.retrieve(
query=query,
top_k=top_k,
fetch_k=top_k * 2,
k=dense_k,
fetch_k=dense_k * 2,
rerank=False, # 稠密检索阶段不进行 rerank
metadata_filters={"kb_id": kb_id},
)
@@ -181,7 +202,7 @@ class RetrievalManager:
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
return all_results[:top_k]
return all_results[: len(all_results) // len(kb_ids)]
async def _rerank(
self,
@@ -7,7 +7,7 @@ from dataclasses import dataclass
from typing import Dict, List
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
@@ -30,7 +30,7 @@ class RankFusion:
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBDatabase, k: int = 60):
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
"""初始化结果融合器
Args:
@@ -8,7 +8,7 @@ from typing import List
from rank_bm25 import BM25Okapi
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
@dataclass
@@ -30,7 +30,7 @@ class SparseRetriever:
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBDatabase):
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化稀疏检索器
Args:
@@ -43,14 +43,14 @@ class SparseRetriever:
self,
query: str,
kb_ids: List[str],
top_k: int = 50,
kb_options: dict,
) -> List[SparseResult]:
"""执行稀疏检索
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k: 返回结果数量
kb_options: 每个知识库的检索选项
Returns:
List[SparseResult]: 检索结果列表
@@ -87,4 +87,4 @@ class SparseRetriever:
)
results.sort(key=lambda x: x.score, reverse=True)
return results[:top_k]
return results[: len(results) // len(kb_ids)]
@@ -1,161 +0,0 @@
"""向量数据库工厂
负责为每个知识库创建和管理独立的向量数据库实例。
架构说明:
- 每个知识库拥有独立的向量数据库实例
- 向量数据库文件以 kb_id 命名
- 工厂类负责实例的创建、缓存和生命周期管理
"""
from pathlib import Path
from typing import Dict, Optional
from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider
class VecDBFactory:
"""向量数据库工厂
职责:
- 为每个知识库创建独立的向量数据库实例
- 缓存已创建的实例以提高性能
- 管理向量数据库的生命周期
"""
def __init__(
self,
storage_base_path: str,
):
"""初始化向量数据库工厂
Args:
storage_base_path: 向量数据库存储基础路径
"""
self.storage_base_path = Path(storage_base_path)
self._instances: Dict[str, BaseVecDB] = {}
# 确保基础路径存在
self.storage_base_path.mkdir(parents=True, exist_ok=True)
async def get_vec_db(
self, kb_id: str, embedding_provider: EmbeddingProvider
) -> BaseVecDB:
"""获取或创建指定知识库的向量数据库实例
Args:
kb_id: 知识库 ID
embedding_provider: Embedding Provider 实例
Returns:
BaseVecDB: 向量数据库实例
"""
# 如果已经创建过,直接返回缓存的实例
if kb_id in self._instances:
return self._instances[kb_id]
# 创建新实例
vec_db = await self._create_vec_db(kb_id, embedding_provider)
self._instances[kb_id] = vec_db
logger.debug(f"创建知识库 {kb_id} 的向量数据库实例")
return vec_db
async def _create_vec_db(
self, kb_id: str, embedding_provider: EmbeddingProvider
) -> BaseVecDB:
"""创建向量数据库实例
Args:
kb_id: 知识库 ID
embedding_provider: Embedding Provider 实例
Returns:
BaseVecDB: 向量数据库实例
"""
# 为每个知识库创建独立的存储路径
kb_storage_path = self.storage_base_path / kb_id
kb_storage_path.mkdir(parents=True, exist_ok=True)
doc_store_path = str(kb_storage_path / "documents.db")
index_store_path = str(kb_storage_path / "index.faiss")
vec_db = FaissVecDB(
doc_store_path=doc_store_path,
index_store_path=index_store_path,
embedding_provider=embedding_provider,
)
await vec_db.initialize()
return vec_db
async def delete_vec_db(self, kb_id: str) -> bool:
"""删除指定知识库的向量数据库
Args:
kb_id: 知识库 ID
Returns:
bool: 是否删除成功
"""
# 关闭并移除缓存的实例
if kb_id in self._instances:
try:
await self._instances[kb_id].close()
except Exception as e:
logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")
del self._instances[kb_id]
# 删除文件系统中的向量数据库文件
kb_storage_path = self.storage_base_path / kb_id
if kb_storage_path.exists():
try:
import shutil
shutil.rmtree(kb_storage_path)
logger.info(f"已删除知识库 {kb_id} 的向量数据库文件")
return True
except Exception as e:
logger.error(f"删除向量数据库文件失败 ({kb_id}): {e}")
return False
return True
async def close_all(self):
"""关闭所有向量数据库实例"""
for kb_id, vec_db in list(self._instances.items()):
try:
await vec_db.close()
logger.debug(f"已关闭知识库 {kb_id} 的向量数据库")
except Exception as e:
logger.warning(f"关闭向量数据库失败 ({kb_id}): {e}")
self._instances.clear()
def has_instance(self, kb_id: str) -> bool:
"""检查是否已创建指定知识库的向量数据库实例
Args:
kb_id: 知识库 ID
Returns:
bool: 是否已创建实例
"""
return kb_id in self._instances
def get_cached_instance(self, kb_id: str) -> Optional[BaseVecDB]:
"""获取已缓存的向量数据库实例(不创建新实例)
Args:
kb_id: 知识库 ID
Returns:
Optional[BaseVecDB]: 向量数据库实例,如果不存在则返回 None
"""
return self._instances.get(kb_id)
+7 -5
View File
@@ -15,13 +15,15 @@ async def inject_kb_context(
p_ctx: Pipeline context
req: Provider request
"""
kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector()
if not kb_injector:
kb_mgr = p_ctx.plugin_manager.context.kb_manager
kb_names = p_ctx.astrbot_config.get("kb_names", [])
if not kb_names:
return
kb_context = await kb_injector.retrieve_and_inject(
unified_msg_origin=umo,
kb_context = await kb_mgr.retrieve(
query=req.prompt,
top_k=top_k,
kb_names=kb_names,
)
if not kb_context:
return
+1 -1
View File
@@ -19,7 +19,7 @@ from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
File diff suppressed because it is too large Load Diff
@@ -21,7 +21,8 @@
"delete": "Delete",
"preview": "Preview",
"search": "Search Chunks",
"searchPlaceholder": "Enter keywords to search chunks..."
"searchPlaceholder": "Enter keywords to search chunks...",
"showing": "Showing"
},
"edit": {
"title": "Edit Chunk",
@@ -21,7 +21,8 @@
"delete": "删除",
"preview": "预览",
"search": "搜索分块",
"searchPlaceholder": "输入关键词搜索分块内容..."
"searchPlaceholder": "输入关键词搜索分块内容...",
"showing": "显示"
},
"edit": {
"title": "编辑分块",
@@ -82,7 +82,7 @@
<v-card-title class="d-flex align-center pa-4">
<span>{{ t('chunks.title') }}</span>
<v-chip class="ml-2" size="small" variant="tonal">
{{ chunks.length }} {{ t('chunks.title') }}
{{ totalChunks }} {{ t('chunks.title') }}
</v-chip>
<v-spacer />
<v-text-field
@@ -104,7 +104,8 @@
:headers="headers"
:items="filteredChunks"
:loading="loadingChunks"
:items-per-page="10"
:items-per-page="pageSize"
hide-default-footer
>
<template #item.chunk_index="{ item }">
<v-chip size="small" variant="tonal" color="primary">
@@ -132,20 +133,6 @@
color="info"
@click="viewChunk(item)"
/>
<v-btn
icon="mdi-pencil"
variant="text"
size="small"
color="primary"
@click="editChunk(item)"
/>
<v-btn
icon="mdi-delete"
variant="text"
size="small"
color="error"
@click="confirmDeleteChunk(item)"
/>
</template>
<template #no-data>
@@ -155,6 +142,31 @@
</div>
</template>
</v-data-table>
<!-- 自定义分页器 -->
<div v-if="!searchQuery && totalChunks > 0" class="pa-4 d-flex align-center justify-space-between">
<div class="text-caption text-medium-emphasis">
{{ t('chunks.showing') }} {{ (page - 1) * pageSize + 1 }} - {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }}
</div>
<div class="d-flex align-center gap-2">
<v-select
v-model="pageSize"
:items="[10, 25, 50, 100]"
density="compact"
variant="outlined"
hide-details
style="width: 100px"
@update:model-value="handlePageSizeChange"
/>
<v-pagination
v-model="page"
:length="Math.ceil(totalChunks / pageSize)"
:total-visible="5"
@update:model-value="handlePageChange"
/>
</div>
</div>
</v-card-text>
</v-card>
</div>
@@ -212,71 +224,6 @@
</v-card>
</v-dialog>
<!-- 编辑分块对话框 -->
<v-dialog v-model="showEditDialog" max-width="800px" persistent scrollable>
<v-card>
<v-card-title class="pa-4">
<span>{{ t('edit.title') }}</span>
<v-spacer />
<v-btn icon="mdi-close" variant="text" @click="closeEditDialog" />
</v-card-title>
<v-divider />
<v-card-text class="pa-6">
<v-textarea
v-model="editForm.content"
:label="t('edit.content')"
variant="outlined"
rows="15"
auto-grow
/>
</v-card-text>
<v-divider />
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="closeEditDialog">
{{ t('edit.cancel') }}
</v-btn>
<v-btn
color="primary"
variant="elevated"
@click="saveChunk"
:loading="saving"
>
{{ t('edit.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 删除确认对话框 -->
<v-dialog v-model="showDeleteDialog" max-width="450px">
<v-card>
<v-card-title class="pa-4 text-h6">{{ t('delete.title') }}</v-card-title>
<v-divider />
<v-card-text class="pa-6">
<p>{{ t('delete.confirmText') }}</p>
<v-alert type="warning" variant="tonal" density="compact" class="mt-4">
{{ t('delete.warning') }}
</v-alert>
</v-card-text>
<v-divider />
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="showDeleteDialog = false">
{{ t('delete.cancel') }}
</v-btn>
<v-btn
color="error"
variant="elevated"
@click="deleteChunk"
:loading="deleting"
>
{{ t('delete.confirm') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 消息提示 -->
<v-snackbar v-model="snackbar.show" :color="snackbar.color">
{{ snackbar.text }}
@@ -299,16 +246,16 @@ const docId = ref(route.params.docId as string)
// 状态
const loading = ref(true)
const loadingChunks = ref(false)
const saving = ref(false)
const deleting = ref(false)
const document = ref<any>({})
const chunks = ref<any[]>([])
const searchQuery = ref('')
const showViewDialog = ref(false)
const showEditDialog = ref(false)
const showDeleteDialog = ref(false)
const selectedChunk = ref<any>(null)
const deleteTarget = ref<any>(null)
// 分页状态
const page = ref(1)
const pageSize = ref(10)
const totalChunks = ref(0)
const snackbar = ref({
show: false,
@@ -322,17 +269,12 @@ const showSnackbar = (text: string, color: string = 'success') => {
snackbar.value.show = true
}
// 编辑表单
const editForm = ref({
content: ''
})
// 表格列
const headers = [
{ title: t('chunks.index'), key: 'chunk_index', width: 100 },
{ title: t('chunks.content'), key: 'content', sortable: false },
{ title: t('chunks.charCount'), key: 'char_count', width: 150 },
{ title: t('chunks.actions'), key: 'actions', sortable: false, align: 'end', width: 150 }
{ title: t('chunks.actions'), key: 'actions', sortable: false, width: 150 }
]
// 过滤分块
@@ -349,7 +291,7 @@ const loadDocument = async () => {
loading.value = true
try {
const response = await axios.get('/api/kb/document/get', {
params: { doc_id: docId.value }
params: { doc_id: docId.value, kb_id: kbId.value }
})
if (response.data.status === 'ok') {
document.value = response.data.data
@@ -367,10 +309,16 @@ const loadChunks = async () => {
loadingChunks.value = true
try {
const response = await axios.get('/api/kb/chunk/list', {
params: { doc_id: docId.value }
params: {
doc_id: docId.value,
kb_id: kbId.value,
page: page.value,
page_size: pageSize.value
}
})
if (response.data.status === 'ok') {
chunks.value = response.data.data.items || []
totalChunks.value = response.data.data.items.length
}
} catch (error) {
console.error('Failed to load chunks:', error)
@@ -380,84 +328,24 @@ const loadChunks = async () => {
}
}
// 处理分页变化
const handlePageChange = (newPage: number) => {
page.value = newPage
loadChunks()
}
const handlePageSizeChange = (newPageSize: number) => {
pageSize.value = newPageSize
page.value = 1
loadChunks()
}
// 查看分块
const viewChunk = (chunk: any) => {
selectedChunk.value = chunk
showViewDialog.value = true
}
// 编辑分块
const editChunk = (chunk: any) => {
selectedChunk.value = chunk
editForm.value.content = chunk.content
showEditDialog.value = true
}
// 关闭编辑对话框
const closeEditDialog = () => {
showEditDialog.value = false
selectedChunk.value = null
editForm.value.content = ''
}
// 保存分块
const saveChunk = async () => {
if (!selectedChunk.value) return
saving.value = true
try {
const response = await axios.post('/api/kb/chunk/update', {
chunk_id: selectedChunk.value.chunk_id,
content: editForm.value.content
})
if (response.data.status === 'ok') {
showSnackbar(t('edit.saveSuccess'))
closeEditDialog()
await loadChunks()
} else {
showSnackbar(response.data.message || t('edit.saveFailed'), 'error')
}
} catch (error) {
console.error('Failed to save chunk:', error)
showSnackbar(t('edit.saveFailed'), 'error')
} finally {
saving.value = false
}
}
// 确认删除分块
const confirmDeleteChunk = (chunk: any) => {
deleteTarget.value = chunk
showDeleteDialog.value = true
}
// 删除分块
const deleteChunk = async () => {
if (!deleteTarget.value) return
deleting.value = true
try {
const response = await axios.post('/api/kb/chunk/delete', {
chunk_id: deleteTarget.value.chunk_id
})
if (response.data.status === 'ok') {
showSnackbar(t('delete.deleteSuccess'))
showDeleteDialog.value = false
await loadChunks()
await loadDocument()
} else {
showSnackbar(response.data.message || t('delete.deleteFailed'), 'error')
}
} catch (error) {
console.error('Failed to delete chunk:', error)
showSnackbar(t('delete.deleteFailed'), 'error')
} finally {
deleting.value = false
}
}
// 工具函数
const getFileIcon = (fileType: string) => {
const type = fileType?.toLowerCase() || ''
@@ -582,6 +470,10 @@ onMounted(() => {
font-family: 'Consolas', 'Monaco', monospace;
}
.gap-2 {
gap: 8px;
}
/* 响应式设计 */
@media (max-width: 768px) {
.document-detail-page {
@@ -51,7 +51,7 @@
</v-tabs>
<!-- 标签页内容 -->
<v-window v-model="activeTab">
<v-window v-model="activeTab" style="padding: 8px;">
<!-- 概览 -->
<v-window-item value="overview">
<v-row>
@@ -163,7 +163,7 @@
<!-- 知识库检索 -->
<v-window-item value="retrieval">
<RetrievalTab :kb-id="kbId" />
<RetrievalTab :kb-id="kbId" :kb-name="kb.kb_name"/>
</v-window-item>
<!-- 使用会话 -->
@@ -329,12 +329,11 @@ onMounted(() => {
padding: 24px;
text-align: center;
border-radius: 12px;
background: rgba(var(--v-theme-surface-variant), 0.3);
background: rgba(var(--v-theme-surface-variant), 0.1);
transition: all 0.3s ease;
}
.stat-box:hover {
transform: translateY(-4px);
background: rgba(var(--v-theme-surface-variant), 0.5);
}
@@ -342,12 +341,10 @@ onMounted(() => {
font-size: 2rem;
font-weight: 600;
margin-top: 8px;
color: rgb(var(--v-theme-on-surface));
}
.stat-label {
font-size: 0.875rem;
color: rgb(var(--v-theme-on-surface-variant));
margin-top: 4px;
}
@@ -7,10 +7,16 @@
</v-card-subtitle>
<v-divider />
<v-progress-linear
v-if="loading"
indeterminate
color="primary"
height="2"
/>
<v-card-text class="pa-6">
<!-- 查询输入区域 -->
<v-row>
<v-row class="mb-4">
<v-col cols="12" md="8">
<v-textarea
v-model="query"
@@ -34,20 +40,7 @@
variant="outlined"
density="compact"
persistent-hint
class="mb-3"
/>
<v-checkbox
v-model="enableRerank"
:label="t('retrieval.enableRerank')"
:hint="t('retrieval.enableRerankHint')"
color="primary"
density="compact"
persistent-hint
/>
<v-alert v-if="enableRerank" type="info" variant="tonal" class="mt-2" density="compact">
如果没有配置重排序模型提供商将跳过重排序步骤
</v-alert>
</v-card>
</v-col>
</v-row>
@@ -143,7 +136,8 @@ import { useModuleI18n } from '@/i18n/composables'
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
const props = defineProps<{
kbId: string
kbId: string,
kbName: string,
}>()
// 状态
@@ -179,7 +173,7 @@ const performRetrieval = async () => {
try {
const response = await axios.post('/api/kb/retrieve', {
query: query.value,
kb_ids: [props.kbId],
kb_names: [props.kbName],
top_k: topK.value,
enable_rerank: enableRerank.value
})