feat: 集成知识库到核心生命周期和消息流水线
- 在 AstrBotCoreLifecycle 中初始化知识库管理器 - 将知识库注入器添加到消息处理上下文 - 在消息流水线中添加 KBEnhanceStage(知识库增强阶段) - 实现会话删除时的知识库配置级联清理机制 - 添加会话管理器的回调注册机制,支持零侵入扩展
This commit is contained in:
@@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
||||
|
||||
import json
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Callable, Awaitable
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation, ConversationV2
|
||||
|
||||
@@ -20,6 +20,38 @@ class ConversationManager:
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
|
||||
# 会话删除回调函数列表(用于级联清理,如知识库配置)
|
||||
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
|
||||
|
||||
def register_on_session_deleted(
|
||||
self, callback: Callable[[str], Awaitable[None]]
|
||||
) -> None:
|
||||
"""注册会话删除回调函数
|
||||
|
||||
其他模块可以注册回调来响应会话删除事件,实现级联清理。
|
||||
例如:知识库模块可以注册回调来清理会话的知识库配置。
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数
|
||||
"""
|
||||
self._on_session_deleted_callbacks.append(callback)
|
||||
|
||||
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
|
||||
"""触发会话删除回调
|
||||
|
||||
Args:
|
||||
unified_msg_origin: 会话ID
|
||||
"""
|
||||
for callback in self._on_session_deleted_callbacks:
|
||||
try:
|
||||
await callback(unified_msg_origin)
|
||||
except Exception as e:
|
||||
from astrbot.core import logger
|
||||
|
||||
logger.error(
|
||||
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
|
||||
)
|
||||
|
||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||
created_at = int(conv_v2.created_at.timestamp())
|
||||
@@ -106,6 +138,9 @@ class ConversationManager:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
# 触发会话删除回调(级联清理)
|
||||
await self._trigger_session_deleted(unified_msg_origin)
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
@@ -132,6 +133,19 @@ class AstrBotCoreLifecycle:
|
||||
# 根据配置实例化各个 Provider
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(
|
||||
self.astrbot_config, self.db, self.provider_manager
|
||||
)
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
# 将知识库注入器添加到 star_context 中,供 Pipeline 使用
|
||||
self.star_context.kb_injector = self.kb_manager.get_kb_injector()
|
||||
|
||||
# 注册知识库会话生命周期钩子(零侵入级联清理)
|
||||
if self.kb_manager.is_initialized:
|
||||
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
@@ -233,6 +247,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
|
||||
# 再次遍历curr_tasks等待每个任务真正结束
|
||||
@@ -248,6 +263,7 @@ class AstrBotCoreLifecycle:
|
||||
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
threading.Thread(
|
||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.message.message_event_result import (
|
||||
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .kb_enhance.stage import KBEnhanceStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
@@ -21,6 +22,7 @@ STAGES_ORDER = [
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"KBEnhanceStage", # 知识库增强
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage", # 发送消息
|
||||
@@ -33,6 +35,7 @@ __all__ = [
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"KBEnhanceStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
"RespondStage",
|
||||
|
||||
72
astrbot/core/pipeline/kb_enhance/stage.py
Normal file
72
astrbot/core/pipeline/kb_enhance/stage.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
知识库增强阶段
|
||||
在 LLM 调用之前,根据会话配置注入知识库上下文
|
||||
"""
|
||||
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class KBEnhanceStage(Stage):
|
||||
"""知识库增强阶段
|
||||
|
||||
功能:
|
||||
- 检查会话是否配置了知识库
|
||||
- 如果配置了知识库,则检索相关知识并注入到事件上下文中
|
||||
- 供后续的 ProcessStage 使用
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.kb_config = self.config.get("knowledge_base", {})
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""处理知识库上下文注入"""
|
||||
|
||||
# 检查知识库功能是否启用
|
||||
if not self.kb_config.get("enabled", False):
|
||||
return
|
||||
|
||||
# 检查是否需要调用知识库 (只有在被@或唤醒时才检索)
|
||||
if not event.is_at_or_wake_command:
|
||||
return
|
||||
|
||||
try:
|
||||
# 从 plugin_manager.context 获取 kb_injector
|
||||
kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None)
|
||||
|
||||
if not kb_injector:
|
||||
logger.debug("知识库注入器未初始化,跳过知识库增强")
|
||||
return
|
||||
|
||||
# 获取会话 ID
|
||||
unified_msg_origin = event.unified_msg_origin
|
||||
|
||||
# 获取用户查询
|
||||
query = event.message_str
|
||||
|
||||
# 检索并注入知识
|
||||
kb_context = await kb_injector.retrieve_and_inject(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
query=query,
|
||||
)
|
||||
|
||||
if kb_context:
|
||||
# 将知识库上下文存储到事件的 extra 中
|
||||
event.set_extra("kb_context", kb_context)
|
||||
logger.debug(
|
||||
f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"知识库增强阶段处理失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
Reference in New Issue
Block a user