feat: 集成知识库到核心生命周期和消息流水线

- 在 AstrBotCoreLifecycle 中初始化知识库管理器
- 将知识库注入器添加到消息处理上下文
- 在消息流水线中添加 KBEnhanceStage(知识库增强阶段)
- 实现会话删除时的知识库配置级联清理机制
- 添加会话管理器的回调注册机制,支持零侵入扩展
This commit is contained in:
lxfight
2025-10-19 18:41:34 +08:00
parent ad96d676e6
commit 98a75e923d
4 changed files with 127 additions and 1 deletions

View File

@@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
import json
from astrbot.core import sp
from typing import Dict, List
from typing import Dict, List, Callable, Awaitable
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
@@ -20,6 +20,38 @@ class ConversationManager:
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
# 会话删除回调函数列表(用于级联清理,如知识库配置)
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
def register_on_session_deleted(
self, callback: Callable[[str], Awaitable[None]]
) -> None:
"""注册会话删除回调函数
其他模块可以注册回调来响应会话删除事件,实现级联清理。
例如:知识库模块可以注册回调来清理会话的知识库配置。
Args:
callback: 回调函数接收会话ID (unified_msg_origin) 作为参数
"""
self._on_session_deleted_callbacks.append(callback)
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
"""触发会话删除回调
Args:
unified_msg_origin: 会话ID
"""
for callback in self._on_session_deleted_callbacks:
try:
await callback(unified_msg_origin)
except Exception as e:
from astrbot.core import logger
logger.error(
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
@@ -106,6 +138,9 @@ class ConversationManager:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
# 触发会话删除回调(级联清理)
await self._trigger_session_deleted(unified_msg_origin)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID

View File

@@ -34,6 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
class AstrBotCoreLifecycle:
@@ -132,6 +133,19 @@ class AstrBotCoreLifecycle:
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.db, self.provider_manager
)
await self.kb_manager.initialize()
# 将知识库注入器添加到 star_context 中,供 Pipeline 使用
self.star_context.kb_injector = self.kb_manager.get_kb_injector()
# 注册知识库会话生命周期钩子(零侵入级联清理)
if self.kb_manager.is_initialized:
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
@@ -233,6 +247,7 @@ class AstrBotCoreLifecycle:
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
@@ -248,6 +263,7 @@ class AstrBotCoreLifecycle:
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True

View File

@@ -5,6 +5,7 @@ from astrbot.core.message.message_event_result import (
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .kb_enhance.stage import KBEnhanceStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
@@ -21,6 +22,7 @@ STAGES_ORDER = [
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"KBEnhanceStage", # 知识库增强
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
@@ -33,6 +35,7 @@ __all__ = [
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"KBEnhanceStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",

View File

@@ -0,0 +1,72 @@
"""
知识库增强阶段
在 LLM 调用之前,根据会话配置注入知识库上下文
"""
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
@register_stage
class KBEnhanceStage(Stage):
"""知识库增强阶段
功能:
- 检查会话是否配置了知识库
- 如果配置了知识库,则检索相关知识并注入到事件上下文中
- 供后续的 ProcessStage 使用
"""
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.kb_config = self.config.get("knowledge_base", {})
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理知识库上下文注入"""
# 检查知识库功能是否启用
if not self.kb_config.get("enabled", False):
return
# 检查是否需要调用知识库 (只有在被@或唤醒时才检索)
if not event.is_at_or_wake_command:
return
try:
# 从 plugin_manager.context 获取 kb_injector
kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None)
if not kb_injector:
logger.debug("知识库注入器未初始化,跳过知识库增强")
return
# 获取会话 ID
unified_msg_origin = event.unified_msg_origin
# 获取用户查询
query = event.message_str
# 检索并注入知识
kb_context = await kb_injector.retrieve_and_inject(
unified_msg_origin=unified_msg_origin,
query=query,
)
if kb_context:
# 将知识库上下文存储到事件的 extra 中
event.set_extra("kb_context", kb_context)
logger.debug(
f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识"
)
except Exception as e:
logger.error(f"知识库增强阶段处理失败: {e}")
import traceback
logger.error(traceback.format_exc())