refactor: 优化知识库管理器和数据库操作的代码格式

This commit is contained in:
lxfight
2025-10-19 19:36:26 +08:00
parent 2563ecf3c5
commit a0254ed817
6 changed files with 34 additions and 20 deletions
@@ -302,7 +302,9 @@ class KnowledgeBaseManager:
async def on_session_deleted(session_id: str):
"""会话删除回调:清理知识库配置"""
try:
await self.kb_database.delete_session_kb_config_by_session_id(session_id)
await self.kb_database.delete_session_kb_config_by_session_id(
session_id
)
logger.info(f"已清理会话知识库配置: {session_id}")
except Exception as e:
logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")
+4 -5
View File
@@ -70,7 +70,8 @@ class KBSQLiteDatabase:
async def initialize(self) -> None:
"""初始化数据库,创建表并配置 SQLite 参数"""
from astrbot.core.knowledge_base.models import (
# noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表
from astrbot.core.knowledge_base.models import ( # noqa: F401
KBChunk,
KBDocument,
KBMedia,
@@ -170,8 +171,7 @@ class KBSQLiteDatabase:
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id "
"ON kb_chunks(kb_id)"
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
)
)
await session.execute(
@@ -196,8 +196,7 @@ class KBSQLiteDatabase:
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id "
"ON kb_media(kb_id)"
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
)
await session.execute(
+14 -6
View File
@@ -330,14 +330,22 @@ class KBManager:
async with self.db.get_db() as session:
async with session.begin():
# 统计文档数(在事务中查询)
doc_count = await session.scalar(
select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
) or 0
doc_count = (
await session.scalar(
select(func.count(KBDocument.id)).where(
KBDocument.kb_id == kb_id
)
)
or 0
)
# 统计块数(在事务中查询)
chunk_count = await session.scalar(
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
) or 0
chunk_count = (
await session.scalar(
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
)
or 0
)
# 更新知识库(在同一事务中)
await session.execute(
+9 -3
View File
@@ -111,7 +111,9 @@ class KBManagerOps:
await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
# 删除文档记录
await session.execute(delete(KBDocument).where(KBDocument.doc_id == doc_id))
await session.execute(
delete(KBDocument).where(KBDocument.doc_id == doc_id)
)
await session.commit()
@@ -183,7 +185,9 @@ class KBManagerOps:
# 3. 删除数据库记录
async with self.db.get_db() as session:
async with session.begin():
await session.execute(delete(KBChunk).where(KBChunk.chunk_id == chunk_id))
await session.execute(
delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
)
await session.commit()
# 4. 更新文档统计
@@ -225,7 +229,9 @@ class KBManagerOps:
# 2. 删除数据库记录
async with self.db.get_db() as session:
async with session.begin():
await session.execute(delete(KBMedia).where(KBMedia.media_id == media_id))
await session.execute(
delete(KBMedia).where(KBMedia.media_id == media_id)
)
await session.commit()
# 3. 删除文件(失败不影响)
+1 -3
View File
@@ -179,6 +179,4 @@ class KBSessionConfig(SQLModel, table=True):
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),
)
__table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),)
@@ -60,7 +60,9 @@ class RankFusion:
List[FusedResult]: 融合后的结果列表
"""
# 1. 构建排名映射
dense_ranks = {r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)}
dense_ranks = {
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
}
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
# 2. 收集所有唯一的 ID (来自稠密检索的是 vec_doc_id, 稀疏检索的是 chunk_id)
@@ -118,7 +120,6 @@ class RankFusion:
)
elif identifier in vec_doc_id_to_dense:
# 从向量检索获取信息,需要从数据库获取块的详细信息
dr = vec_doc_id_to_dense[identifier]
chunk = await self.kb_db.get_chunk_by_vec_doc_id(identifier)
if chunk:
fused_results.append(