refactor: 优化知识库管理器和数据库操作的代码格式
This commit is contained in:
@@ -302,7 +302,9 @@ class KnowledgeBaseManager:
|
||||
async def on_session_deleted(session_id: str):
|
||||
"""会话删除回调:清理知识库配置"""
|
||||
try:
|
||||
await self.kb_database.delete_session_kb_config_by_session_id(session_id)
|
||||
await self.kb_database.delete_session_kb_config_by_session_id(
|
||||
session_id
|
||||
)
|
||||
logger.info(f"已清理会话知识库配置: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")
|
||||
|
||||
@@ -70,7 +70,8 @@ class KBSQLiteDatabase:
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
# noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表
|
||||
from astrbot.core.knowledge_base.models import ( # noqa: F401
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
@@ -170,8 +171,7 @@ class KBSQLiteDatabase:
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id "
|
||||
"ON kb_chunks(kb_id)"
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
@@ -196,8 +196,7 @@ class KBSQLiteDatabase:
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id "
|
||||
"ON kb_media(kb_id)"
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
|
||||
@@ -330,14 +330,22 @@ class KBManager:
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计文档数(在事务中查询)
|
||||
doc_count = await session.scalar(
|
||||
select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
|
||||
) or 0
|
||||
doc_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBDocument.id)).where(
|
||||
KBDocument.kb_id == kb_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计块数(在事务中查询)
|
||||
chunk_count = await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
|
||||
) or 0
|
||||
chunk_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# 更新知识库(在同一事务中)
|
||||
await session.execute(
|
||||
|
||||
@@ -111,7 +111,9 @@ class KBManagerOps:
|
||||
await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
|
||||
|
||||
# 删除文档记录
|
||||
await session.execute(delete(KBDocument).where(KBDocument.doc_id == doc_id))
|
||||
await session.execute(
|
||||
delete(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@@ -183,7 +185,9 @@ class KBManagerOps:
|
||||
# 3. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(delete(KBChunk).where(KBChunk.chunk_id == chunk_id))
|
||||
await session.execute(
|
||||
delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 4. 更新文档统计
|
||||
@@ -225,7 +229,9 @@ class KBManagerOps:
|
||||
# 2. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(delete(KBMedia).where(KBMedia.media_id == media_id))
|
||||
await session.execute(
|
||||
delete(KBMedia).where(KBMedia.media_id == media_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 3. 删除文件(失败不影响)
|
||||
|
||||
@@ -179,6 +179,4 @@ class KBSessionConfig(SQLModel, table=True):
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),)
|
||||
|
||||
@@ -60,7 +60,9 @@ class RankFusion:
|
||||
List[FusedResult]: 融合后的结果列表
|
||||
"""
|
||||
# 1. 构建排名映射
|
||||
dense_ranks = {r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)}
|
||||
dense_ranks = {
|
||||
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
|
||||
}
|
||||
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
|
||||
|
||||
# 2. 收集所有唯一的 ID (来自稠密检索的是 vec_doc_id, 稀疏检索的是 chunk_id)
|
||||
@@ -118,7 +120,6 @@ class RankFusion:
|
||||
)
|
||||
elif identifier in vec_doc_id_to_dense:
|
||||
# 从向量检索获取信息,需要从数据库获取块的详细信息
|
||||
dr = vec_doc_id_to_dense[identifier]
|
||||
chunk = await self.kb_db.get_chunk_by_vec_doc_id(identifier)
|
||||
if chunk:
|
||||
fused_results.append(
|
||||
|
||||
Reference in New Issue
Block a user