From a0254ed817a293c8f7687a2c93a7b5aadf2a1b88 Mon Sep 17 00:00:00 2001 From: lxfight <1686540385@qq.com> Date: Sun, 19 Oct 2025 19:36:26 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E7=AE=A1=E7=90=86=E5=99=A8=E5=92=8C=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E7=9A=84=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../knowledge_base/kb_manager_lifecycle.py | 4 +++- astrbot/core/knowledge_base/kb_sqlite.py | 9 ++++----- astrbot/core/knowledge_base/manager.py | 20 +++++++++++++------ astrbot/core/knowledge_base/manager_ops.py | 12 ++++++++--- astrbot/core/knowledge_base/models.py | 4 +--- .../knowledge_base/retrieval/rank_fusion.py | 5 +++-- 6 files changed, 34 insertions(+), 20 deletions(-) diff --git a/astrbot/core/knowledge_base/kb_manager_lifecycle.py b/astrbot/core/knowledge_base/kb_manager_lifecycle.py index 51830769..161872b4 100644 --- a/astrbot/core/knowledge_base/kb_manager_lifecycle.py +++ b/astrbot/core/knowledge_base/kb_manager_lifecycle.py @@ -302,7 +302,9 @@ class KnowledgeBaseManager: async def on_session_deleted(session_id: str): """会话删除回调:清理知识库配置""" try: - await self.kb_database.delete_session_kb_config_by_session_id(session_id) + await self.kb_database.delete_session_kb_config_by_session_id( + session_id + ) logger.info(f"已清理会话知识库配置: {session_id}") except Exception as e: logger.error(f"清理会话知识库配置失败 ({session_id}): {e}") diff --git a/astrbot/core/knowledge_base/kb_sqlite.py b/astrbot/core/knowledge_base/kb_sqlite.py index c42d2b4b..526b6277 100644 --- a/astrbot/core/knowledge_base/kb_sqlite.py +++ b/astrbot/core/knowledge_base/kb_sqlite.py @@ -70,7 +70,8 @@ class KBSQLiteDatabase: async def initialize(self) -> None: """初始化数据库,创建表并配置 SQLite 参数""" - from astrbot.core.knowledge_base.models import ( + # noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表 + from astrbot.core.knowledge_base.models import ( # noqa: F401 KBChunk, KBDocument, KBMedia, @@ -170,8 +171,7 @@ class KBSQLiteDatabase: ) await session.execute( text( - "CREATE INDEX IF NOT EXISTS idx_chunk_kb_id " - "ON kb_chunks(kb_id)" + "CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)" ) ) await session.execute( @@ -196,8 +196,7 @@ class KBSQLiteDatabase: ) await session.execute( text( - "CREATE INDEX IF NOT EXISTS idx_media_kb_id " - "ON kb_media(kb_id)" + "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)" ) ) await session.execute( diff --git a/astrbot/core/knowledge_base/manager.py b/astrbot/core/knowledge_base/manager.py index 98462941..497f64ac 100644 --- a/astrbot/core/knowledge_base/manager.py +++ b/astrbot/core/knowledge_base/manager.py @@ -330,14 +330,22 @@ class KBManager: async with self.db.get_db() as session: async with session.begin(): # 统计文档数(在事务中查询) - doc_count = await session.scalar( - select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id) - ) or 0 + doc_count = ( + await session.scalar( + select(func.count(KBDocument.id)).where( + KBDocument.kb_id == kb_id + ) + ) + or 0 + ) # 统计块数(在事务中查询) - chunk_count = await session.scalar( - select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id) - ) or 0 + chunk_count = ( + await session.scalar( + select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id) + ) + or 0 + ) # 更新知识库(在同一事务中) await session.execute( diff --git a/astrbot/core/knowledge_base/manager_ops.py b/astrbot/core/knowledge_base/manager_ops.py index 521d3de5..e0ab5f6d 100644 --- a/astrbot/core/knowledge_base/manager_ops.py +++ b/astrbot/core/knowledge_base/manager_ops.py @@ -111,7 +111,9 @@ class KBManagerOps: await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id)) # 删除文档记录 - await session.execute(delete(KBDocument).where(KBDocument.doc_id == doc_id)) + await session.execute( + delete(KBDocument).where(KBDocument.doc_id == doc_id) + ) await session.commit() @@ -183,7 +185,9 @@ class KBManagerOps: # 3. 删除数据库记录 async with self.db.get_db() as session: async with session.begin(): - await session.execute(delete(KBChunk).where(KBChunk.chunk_id == chunk_id)) + await session.execute( + delete(KBChunk).where(KBChunk.chunk_id == chunk_id) + ) await session.commit() # 4. 更新文档统计 @@ -225,7 +229,9 @@ class KBManagerOps: # 2. 删除数据库记录 async with self.db.get_db() as session: async with session.begin(): - await session.execute(delete(KBMedia).where(KBMedia.media_id == media_id)) + await session.execute( + delete(KBMedia).where(KBMedia.media_id == media_id) + ) await session.commit() # 3. 删除文件(失败不影响) diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index 44e51928..28adbaa0 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -179,6 +179,4 @@ class KBSessionConfig(SQLModel, table=True): sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) - __table_args__ = ( - UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"), - ) + __table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),) diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 0dd483c1..b05fe1be 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -60,7 +60,9 @@ class RankFusion: List[FusedResult]: 融合后的结果列表 """ # 1. 构建排名映射 - dense_ranks = {r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)} + dense_ranks = { + r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results) + } sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)} # 2. 收集所有唯一的 ID (来自稠密检索的是 vec_doc_id, 稀疏检索的是 chunk_id) @@ -118,7 +120,6 @@ class RankFusion: ) elif identifier in vec_doc_id_to_dense: # 从向量检索获取信息,需要从数据库获取块的详细信息 - dr = vec_doc_id_to_dense[identifier] chunk = await self.kb_db.get_chunk_by_vec_doc_id(identifier) if chunk: fused_results.append(