feat: 集成知识库到核心生命周期和消息流水线
- 在 AstrBotCoreLifecycle 中初始化知识库管理器 - 将知识库注入器添加到消息处理上下文 - 在消息流水线中添加 KBEnhanceStage(知识库增强阶段) - 实现会话删除时的知识库配置级联清理机制 - 添加会话管理器的回调注册机制,支持零侵入扩展
This commit is contained in:
@@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from astrbot.core import sp
|
from astrbot.core import sp
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Callable, Awaitable
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core.db.po import Conversation, ConversationV2
|
from astrbot.core.db.po import Conversation, ConversationV2
|
||||||
|
|
||||||
@@ -20,6 +20,38 @@ class ConversationManager:
|
|||||||
self.db = db_helper
|
self.db = db_helper
|
||||||
self.save_interval = 60 # 每 60 秒保存一次
|
self.save_interval = 60 # 每 60 秒保存一次
|
||||||
|
|
||||||
|
# 会话删除回调函数列表(用于级联清理,如知识库配置)
|
||||||
|
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
|
||||||
|
|
||||||
|
def register_on_session_deleted(
|
||||||
|
self, callback: Callable[[str], Awaitable[None]]
|
||||||
|
) -> None:
|
||||||
|
"""注册会话删除回调函数
|
||||||
|
|
||||||
|
其他模块可以注册回调来响应会话删除事件,实现级联清理。
|
||||||
|
例如:知识库模块可以注册回调来清理会话的知识库配置。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数
|
||||||
|
"""
|
||||||
|
self._on_session_deleted_callbacks.append(callback)
|
||||||
|
|
||||||
|
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
|
||||||
|
"""触发会话删除回调
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin: 会话ID
|
||||||
|
"""
|
||||||
|
for callback in self._on_session_deleted_callbacks:
|
||||||
|
try:
|
||||||
|
await callback(unified_msg_origin)
|
||||||
|
except Exception as e:
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||||
created_at = int(conv_v2.created_at.timestamp())
|
created_at = int(conv_v2.created_at.timestamp())
|
||||||
@@ -106,6 +138,9 @@ class ConversationManager:
|
|||||||
self.session_conversations.pop(unified_msg_origin, None)
|
self.session_conversations.pop(unified_msg_origin, None)
|
||||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||||
|
|
||||||
|
# 触发会话删除回调(级联清理)
|
||||||
|
await self._trigger_session_deleted(unified_msg_origin)
|
||||||
|
|
||||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||||
"""获取会话当前的对话 ID
|
"""获取会话当前的对话 ID
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryMana
|
|||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star_handler import star_map
|
from astrbot.core.star.star_handler import star_map
|
||||||
|
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||||
|
|
||||||
|
|
||||||
class AstrBotCoreLifecycle:
|
class AstrBotCoreLifecycle:
|
||||||
@@ -132,6 +133,19 @@ class AstrBotCoreLifecycle:
|
|||||||
# 根据配置实例化各个 Provider
|
# 根据配置实例化各个 Provider
|
||||||
await self.provider_manager.initialize()
|
await self.provider_manager.initialize()
|
||||||
|
|
||||||
|
# 初始化知识库管理器
|
||||||
|
self.kb_manager = KnowledgeBaseManager(
|
||||||
|
self.astrbot_config, self.db, self.provider_manager
|
||||||
|
)
|
||||||
|
await self.kb_manager.initialize()
|
||||||
|
|
||||||
|
# 将知识库注入器添加到 star_context 中,供 Pipeline 使用
|
||||||
|
self.star_context.kb_injector = self.kb_manager.get_kb_injector()
|
||||||
|
|
||||||
|
# 注册知识库会话生命周期钩子(零侵入级联清理)
|
||||||
|
if self.kb_manager.is_initialized:
|
||||||
|
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
|
||||||
|
|
||||||
# 初始化消息事件流水线调度器
|
# 初始化消息事件流水线调度器
|
||||||
|
|
||||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||||
@@ -233,6 +247,7 @@ class AstrBotCoreLifecycle:
|
|||||||
|
|
||||||
await self.provider_manager.terminate()
|
await self.provider_manager.terminate()
|
||||||
await self.platform_manager.terminate()
|
await self.platform_manager.terminate()
|
||||||
|
await self.kb_manager.terminate()
|
||||||
self.dashboard_shutdown_event.set()
|
self.dashboard_shutdown_event.set()
|
||||||
|
|
||||||
# 再次遍历curr_tasks等待每个任务真正结束
|
# 再次遍历curr_tasks等待每个任务真正结束
|
||||||
@@ -248,6 +263,7 @@ class AstrBotCoreLifecycle:
|
|||||||
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||||
await self.provider_manager.terminate()
|
await self.provider_manager.terminate()
|
||||||
await self.platform_manager.terminate()
|
await self.platform_manager.terminate()
|
||||||
|
await self.kb_manager.terminate()
|
||||||
self.dashboard_shutdown_event.set()
|
self.dashboard_shutdown_event.set()
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from astrbot.core.message.message_event_result import (
|
|||||||
|
|
||||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||||
from .preprocess_stage.stage import PreProcessStage
|
from .preprocess_stage.stage import PreProcessStage
|
||||||
|
from .kb_enhance.stage import KBEnhanceStage
|
||||||
from .process_stage.stage import ProcessStage
|
from .process_stage.stage import ProcessStage
|
||||||
from .rate_limit_check.stage import RateLimitStage
|
from .rate_limit_check.stage import RateLimitStage
|
||||||
from .respond.stage import RespondStage
|
from .respond.stage import RespondStage
|
||||||
@@ -21,6 +22,7 @@ STAGES_ORDER = [
|
|||||||
"RateLimitStage", # 检查会话是否超过频率限制
|
"RateLimitStage", # 检查会话是否超过频率限制
|
||||||
"ContentSafetyCheckStage", # 检查内容安全
|
"ContentSafetyCheckStage", # 检查内容安全
|
||||||
"PreProcessStage", # 预处理
|
"PreProcessStage", # 预处理
|
||||||
|
"KBEnhanceStage", # 知识库增强
|
||||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||||
"RespondStage", # 发送消息
|
"RespondStage", # 发送消息
|
||||||
@@ -33,6 +35,7 @@ __all__ = [
|
|||||||
"RateLimitStage",
|
"RateLimitStage",
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
"PreProcessStage",
|
"PreProcessStage",
|
||||||
|
"KBEnhanceStage",
|
||||||
"ProcessStage",
|
"ProcessStage",
|
||||||
"ResultDecorateStage",
|
"ResultDecorateStage",
|
||||||
"RespondStage",
|
"RespondStage",
|
||||||
|
|||||||
72
astrbot/core/pipeline/kb_enhance/stage.py
Normal file
72
astrbot/core/pipeline/kb_enhance/stage.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""
|
||||||
|
知识库增强阶段
|
||||||
|
在 LLM 调用之前,根据会话配置注入知识库上下文
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Union, AsyncGenerator
|
||||||
|
from ..stage import Stage, register_stage
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
@register_stage
|
||||||
|
class KBEnhanceStage(Stage):
|
||||||
|
"""知识库增强阶段
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 检查会话是否配置了知识库
|
||||||
|
- 如果配置了知识库,则检索相关知识并注入到事件上下文中
|
||||||
|
- 供后续的 ProcessStage 使用
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
self.ctx = ctx
|
||||||
|
self.config = ctx.astrbot_config
|
||||||
|
self.kb_config = self.config.get("knowledge_base", {})
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
"""处理知识库上下文注入"""
|
||||||
|
|
||||||
|
# 检查知识库功能是否启用
|
||||||
|
if not self.kb_config.get("enabled", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否需要调用知识库 (只有在被@或唤醒时才检索)
|
||||||
|
if not event.is_at_or_wake_command:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 从 plugin_manager.context 获取 kb_injector
|
||||||
|
kb_injector = getattr(self.ctx.plugin_manager.context, "kb_injector", None)
|
||||||
|
|
||||||
|
if not kb_injector:
|
||||||
|
logger.debug("知识库注入器未初始化,跳过知识库增强")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取会话 ID
|
||||||
|
unified_msg_origin = event.unified_msg_origin
|
||||||
|
|
||||||
|
# 获取用户查询
|
||||||
|
query = event.message_str
|
||||||
|
|
||||||
|
# 检索并注入知识
|
||||||
|
kb_context = await kb_injector.retrieve_and_inject(
|
||||||
|
unified_msg_origin=unified_msg_origin,
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
|
||||||
|
if kb_context:
|
||||||
|
# 将知识库上下文存储到事件的 extra 中
|
||||||
|
event.set_extra("kb_context", kb_context)
|
||||||
|
logger.debug(
|
||||||
|
f"知识库上下文已注入,检索到 {len(kb_context.get('results', []))} 条相关知识"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"知识库增强阶段处理失败: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
Reference in New Issue
Block a user