feat: 集成知识库到核心生命周期和消息流水线

- 在 AstrBotCoreLifecycle 中初始化知识库管理器
- 将知识库注入器添加到消息处理上下文
- 在消息流水线中添加 KBEnhanceStage(知识库增强阶段)
- 实现会话删除时的知识库配置级联清理机制
- 添加会话管理器的回调注册机制,支持零侵入扩展
This commit is contained in:
lxfight
2025-10-19 18:41:34 +08:00
parent ad96d676e6
commit 98a75e923d
4 changed files with 127 additions and 1 deletions

View File

@@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
import json import json
from astrbot.core import sp from astrbot.core import sp
from typing import Dict, List from typing import Dict, List, Callable, Awaitable
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2 from astrbot.core.db.po import Conversation, ConversationV2
@@ -20,6 +20,38 @@ class ConversationManager:
self.db = db_helper self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次 self.save_interval = 60 # 每 60 秒保存一次
# 会话删除回调函数列表(用于级联清理,如知识库配置)
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
def register_on_session_deleted(
self, callback: Callable[[str], Awaitable[None]]
) -> None:
"""注册会话删除回调函数
其他模块可以注册回调来响应会话删除事件,实现级联清理。
例如:知识库模块可以注册回调来清理会话的知识库配置。
Args:
callback: 回调函数接收会话ID (unified_msg_origin) 作为参数
"""
self._on_session_deleted_callbacks.append(callback)
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
"""触发会话删除回调
Args:
unified_msg_origin: 会话ID
"""
for callback in self._on_session_deleted_callbacks:
try:
await callback(unified_msg_origin)
except Exception as e:
from astrbot.core import logger
logger.error(
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象""" """将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp()) created_at = int(conv_v2.created_at.timestamp())
@@ -106,6 +138,9 @@ class ConversationManager:
self.session_conversations.pop(unified_msg_origin, None) self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id") await sp.session_remove(unified_msg_origin, "sel_conv_id")
# 触发会话删除回调(级联清理)
await self._trigger_session_deleted(unified_msg_origin)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None: async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID """获取会话当前的对话 ID

View File

@@ -34,6 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
class AstrBotCoreLifecycle: class AstrBotCoreLifecycle:
@@ -132,6 +133,19 @@ class AstrBotCoreLifecycle:
# 根据配置实例化各个 Provider # 根据配置实例化各个 Provider
await self.provider_manager.initialize() await self.provider_manager.initialize()
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.db, self.provider_manager
)
await self.kb_manager.initialize()
# 将知识库注入器添加到 star_context 中,供 Pipeline 使用
self.star_context.kb_injector = self.kb_manager.get_kb_injector()
# 注册知识库会话生命周期钩子(零侵入级联清理)
if self.kb_manager.is_initialized:
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
# 初始化消息事件流水线调度器 # 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler() self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
@@ -233,6 +247,7 @@ class AstrBotCoreLifecycle:
await self.provider_manager.terminate() await self.provider_manager.terminate()
await self.platform_manager.terminate() await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set() self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束 # 再次遍历curr_tasks等待每个任务真正结束
@@ -248,6 +263,7 @@ class AstrBotCoreLifecycle:
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate() await self.provider_manager.terminate()
await self.platform_manager.terminate() await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set() self.dashboard_shutdown_event.set()
threading.Thread( threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True target=self.astrbot_updator._reboot, name="restart", daemon=True

View File

@@ -5,6 +5,7 @@ from astrbot.core.message.message_event_result import (
from .content_safety_check.stage import ContentSafetyCheckStage from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage from .preprocess_stage.stage import PreProcessStage
from .kb_enhance.stage import KBEnhanceStage
from .process_stage.stage import ProcessStage from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage from .respond.stage import RespondStage
@@ -21,6 +22,7 @@ STAGES_ORDER = [
"RateLimitStage", # 检查会话是否超过频率限制 "RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全 "ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理 "PreProcessStage", # 预处理
"KBEnhanceStage", # 知识库增强
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用 "ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等 "ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息 "RespondStage", # 发送消息
@@ -33,6 +35,7 @@ __all__ = [
"RateLimitStage", "RateLimitStage",
"ContentSafetyCheckStage", "ContentSafetyCheckStage",
"PreProcessStage", "PreProcessStage",
"KBEnhanceStage",
"ProcessStage", "ProcessStage",
"ResultDecorateStage", "ResultDecorateStage",
"RespondStage", "RespondStage",

View File

@@ -0,0 +1,72 @@
"""
知识库增强阶段
在 LLM 调用之前,根据会话配置注入知识库上下文
"""
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
@register_stage
class KBEnhanceStage(Stage):
"""知识库增强阶段
功能:
- 检查会话是否配置了知识库
- 如果配置了知识库,则检索相关知识并注入到事件上下文中
- 供后续的 ProcessStage 使用
"""
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.kb_config = self.config.get("knowledge_base", {})
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理知识库上下文注入"""
# 检查知识库功能是否启用
if not self.kb_config.get("enabled", False):
return
# 检查是否需要调用知识库 (只有在被@或唤醒时才检索)
if not event.is_at_or_wake_command:
return
try:
# 从 plugin_manager.context 获取 kb_injector
kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None)
if not kb_injector:
logger.debug("知识库注入器未初始化,跳过知识库增强")
return
# 获取会话 ID
unified_msg_origin = event.unified_msg_origin
# 获取用户查询
query = event.message_str
# 检索并注入知识
kb_context = await kb_injector.retrieve_and_inject(
unified_msg_origin=unified_msg_origin,
query=query,
)
if kb_context:
# 将知识库上下文存储到事件的 extra 中
event.set_extra("kb_context", kb_context)
logger.debug(
f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识"
)
except Exception as e:
logger.error(f"知识库增强阶段处理失败: {e}")
import traceback
logger.error(traceback.format_exc())