Files
AstrBot/astrbot/core/conversation_mgr.py
Soulter 50144ddcae refactor: revise LLM message schema and fix the reload logic when using dataclass-based LLM Tool registration (#3234)
* refactor: llm message schema

* feat: implement MCPTool and local LLM tools with enhanced context handling

* refactor: reorganize imports and enhance docstrings for clarity

* refactor: enhance ContentPart validation and add message pair handling in ConversationManager

* chore: ruff format

* refactor: remove debug print statement from payloads in ProviderOpenAIOfficial

* Update astrbot/core/agent/tool.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/message.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/message.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/tool.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/pipeline/process_stage/method/llm_request.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/agent/message.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* refactor: enhance documentation and import mcp in tool.py; update call method return type

* fix: 修复以数据类的方式注册 tool 时的插件重载机制问题

* refactor: change role attributes to use Literal types for message segments

* fix: add support for 'decorator_handler' method in call_local_llm_tool

* fix: handle None prompt in text_chat method and ensure context is properly formatted

---------

Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-02 18:12:20 +08:00

410 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库.
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import json
from collections.abc import Awaitable, Callable
from astrbot.core import sp
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
class ConversationManager:
    """Track which LLM conversation each session is currently using.

    A session is identified by ``unified_msg_origin`` and may own several
    conversations. The currently selected conversation ID of each session is
    cached in ``session_conversations`` and mirrored to ``sp`` under the key
    ``"sel_conv_id"``; conversation records themselves are read and written
    through the database helper passed to the constructor.
    """

    def __init__(self, db_helper: BaseDatabase):
        """Create a manager bound to the given database helper.

        Args:
            db_helper: Database helper used for every conversation
                create/read/update/delete operation in this class.
        """
        # In-memory cache: unified_msg_origin -> selected conversation ID.
        self.session_conversations: dict[str, str] = {}
        self.db = db_helper
        # Save interval in seconds (original note: "save every 60 seconds").
        # Not read anywhere inside this class.
        self.save_interval = 60

        # Callbacks run after all conversations of a session are deleted
        # (used for cascading clean-up, e.g. knowledge-base configuration).
        self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []

    def register_on_session_deleted(
        self,
        callback: Callable[[str], Awaitable[None]],
    ) -> None:
        """Register a callback that is awaited when a session is deleted.

        Other modules can register callbacks to perform cascading clean-up,
        for example removing a session's knowledge-base configuration.

        Args:
            callback: Async callable receiving the session ID
                (``unified_msg_origin``) as its only argument.
        """
        self._on_session_deleted_callbacks.append(callback)

    async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
        """Await every registered session-deleted callback in order.

        A failing callback is logged and does not prevent the remaining
        callbacks from running.

        Args:
            unified_msg_origin: Session ID passed to each callback.
        """
        for callback in self._on_session_deleted_callbacks:
            try:
                await callback(unified_msg_origin)
            except Exception as e:
                # Imported lazily, exactly as the original code did;
                # presumably to avoid an import cycle - TODO confirm.
                from astrbot.core import logger

                logger.error(
                    f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
                )

    def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
        """Convert a ``ConversationV2`` record into a ``Conversation`` object.

        The timestamps are converted to integer epoch seconds and the content
        list is serialised to a JSON string (an empty list when the content
        is falsy).

        Args:
            conv_v2: Record returned by the database helper.

        Returns:
            The equivalent ``Conversation`` object.
        """
        created_at = int(conv_v2.created_at.timestamp())
        updated_at = int(conv_v2.updated_at.timestamp())
        return Conversation(
            platform_id=conv_v2.platform_id,
            user_id=conv_v2.user_id,
            cid=conv_v2.conversation_id,
            history=json.dumps(conv_v2.content or []),
            title=conv_v2.title,
            persona_id=conv_v2.persona_id,
            created_at=created_at,
            updated_at=updated_at,
        )

    async def new_conversation(
        self,
        unified_msg_origin: str,
        platform_id: str | None = None,
        content: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
    ) -> str:
        """Create a conversation and make it the session's current one.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            platform_id: Platform ID. When omitted it is taken from the first
                segment of ``unified_msg_origin`` if that string has at least
                three ``:``-separated parts, otherwise ``"unknown"`` is used.
            content: Initial history passed to the database helper.
            title: Conversation title passed to the database helper.
            persona_id: Persona ID passed to the database helper.

        Returns:
            The ID of the newly created conversation (documented by the
            original code as a UUID-formatted string).
        """
        if not platform_id:
            # No explicit platform: derive it from unified_msg_origin.
            parts = unified_msg_origin.split(":")
            if len(parts) >= 3:
                platform_id = parts[0]
            if not platform_id:
                platform_id = "unknown"

        conv = await self.db.create_conversation(
            user_id=unified_msg_origin,
            platform_id=platform_id,
            content=content,
            title=title,
            persona_id=persona_id,
        )
        self.session_conversations[unified_msg_origin] = conv.conversation_id
        await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
        return conv.conversation_id

    async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
        """Select ``conversation_id`` as the session's current conversation.

        The selection is written to both the in-memory cache and ``sp``.
        No check is made here that the conversation exists.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            conversation_id: ID of the conversation to switch to.
        """
        self.session_conversations[unified_msg_origin] = conversation_id
        await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)

    async def delete_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str | None = None,
    ):
        """Delete a conversation of the session.

        When ``conversation_id`` is ``None`` the session's current
        conversation is deleted. If the deleted conversation is the current
        one, the stored selection is cleared as well.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            conversation_id: ID of the conversation to delete, or ``None``
                to delete the current conversation.
        """
        # Resolve the current selection through get_curr_conversation_id so
        # that a selection persisted in ``sp`` is honoured even when the
        # in-memory cache is cold; reading only the cache would turn the
        # "delete current conversation" request into a silent no-op.
        curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
        if not conversation_id:
            conversation_id = curr_cid
        if not conversation_id:
            return

        await self.db.delete_conversation(cid=conversation_id)
        if curr_cid == conversation_id:
            self.session_conversations.pop(unified_msg_origin, None)
            await sp.session_remove(unified_msg_origin, "sel_conv_id")

    async def delete_conversations_by_user_id(self, unified_msg_origin: str):
        """Delete every conversation of the session and clear its selection.

        After the deletion the registered session-deleted callbacks are
        triggered for cascading clean-up.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
        """
        await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
        self.session_conversations.pop(unified_msg_origin, None)
        await sp.session_remove(unified_msg_origin, "sel_conv_id")

        # Cascade the deletion to interested modules.
        await self._trigger_session_deleted(unified_msg_origin)

    async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
        """Return the ID of the session's current conversation.

        The in-memory cache is consulted first; on a miss the value stored in
        ``sp`` is read and, when present, cached.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.

        Returns:
            The current conversation ID, or ``None`` when nothing is selected.
        """
        ret = self.session_conversations.get(unified_msg_origin, None)
        if not ret:
            ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
            if ret:
                self.session_conversations[unified_msg_origin] = ret
        return ret

    async def get_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str,
        create_if_not_exists: bool = False,
    ) -> Conversation | None:
        """Fetch a conversation by ID.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            conversation_id: ID of the conversation to fetch.
            create_if_not_exists: When ``True`` and the conversation is not
                found, a new conversation is created for the session (which
                also becomes its current conversation) and returned instead.

        Returns:
            The conversation, or ``None`` when it does not exist and was not
            created.
        """
        conv = await self.db.get_conversation_by_id(cid=conversation_id)
        if not conv and create_if_not_exists:
            # Missing conversation: create a fresh one and load it.
            conversation_id = await self.new_conversation(unified_msg_origin)
            conv = await self.db.get_conversation_by_id(cid=conversation_id)
        if not conv:
            return None
        return self._convert_conv_from_v2_to_v1(conv)

    async def get_conversations(
        self,
        unified_msg_origin: str | None = None,
        platform_id: str | None = None,
    ) -> list[Conversation]:
        """List conversations, optionally filtered by session and platform.

        Args:
            unified_msg_origin: Optional unified message origin string,
                formatted as ``platform_name:message_type:session_id``;
                forwarded to the database helper as ``user_id``.
            platform_id: Optional platform ID used to filter conversations.

        Returns:
            The matching conversations.
        """
        convs = await self.db.get_conversations(
            user_id=unified_msg_origin,
            platform_id=platform_id,
        )
        return [self._convert_conv_from_v2_to_v1(conv) for conv in convs]

    async def get_filtered_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
        platform_ids: list[str] | None = None,
        search_query: str = "",
        **kwargs,
    ) -> tuple[list[Conversation], int]:
        """Return one page of conversations matching the given filters.

        Args:
            page: Page number, defaults to 1.
            page_size: Page size, defaults to 20.
            platform_ids: Optional list of platform IDs.
            search_query: Optional search string.
            **kwargs: Extra filters forwarded unchanged to the database
                helper.

        Returns:
            A tuple of the conversations on the requested page and the count
            reported by the database helper (presumably the total number of
            matches - confirm against the helper).
        """
        convs, cnt = await self.db.get_filtered_conversations(
            page=page,
            page_size=page_size,
            platform_ids=platform_ids,
            search_query=search_query,
            **kwargs,
        )
        return [self._convert_conv_from_v2_to_v1(conv) for conv in convs], cnt

    async def update_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str | None = None,
        history: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
    ) -> None:
        """Update a conversation of the session.

        Nothing happens when no conversation ID is given and the session has
        no current conversation.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            conversation_id: ID of the conversation to update; defaults to
                the session's current conversation.
            history: Conversation history as a list of dicts, each holding
                ``role`` and ``content`` fields.
            title: Conversation title forwarded to the database helper.
            persona_id: Persona ID forwarded to the database helper.
        """
        if not conversation_id:
            # Fall back to the session's current conversation.
            conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
        if conversation_id:
            await self.db.update_conversation(
                cid=conversation_id,
                title=title,
                persona_id=persona_id,
                content=history,
            )

    async def update_conversation_title(
        self,
        unified_msg_origin: str,
        title: str,
        conversation_id: str | None = None,
    ) -> None:
        """Update the title of a conversation.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            title: New conversation title.
            conversation_id: ID of the conversation to update; defaults to
                the session's current conversation.

        Deprecated:
            Use `update_conversation` with `title` parameter instead.
        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,
            conversation_id=conversation_id,
            title=title,
        )

    async def update_conversation_persona_id(
        self,
        unified_msg_origin: str,
        persona_id: str,
        conversation_id: str | None = None,
    ) -> None:
        """Update the persona ID of a conversation.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            persona_id: New persona ID.
            conversation_id: ID of the conversation to update; defaults to
                the session's current conversation.

        Deprecated:
            Use `update_conversation` with `persona_id` parameter instead.
        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,
            conversation_id=conversation_id,
            persona_id=persona_id,
        )

    async def add_message_pair(
        self,
        cid: str,
        user_message: UserMessageSegment | dict,
        assistant_message: AssistantMessageSegment | dict,
    ) -> None:
        """Add a user-assistant message pair to the conversation history.

        Args:
            cid (str): Conversation ID
            user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
            assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict

        Raises:
            Exception: If the conversation with the given ID is not found
        """
        conv = await self.db.get_conversation_by_id(cid=cid)
        if not conv:
            raise Exception(f"Conversation with id {cid} not found")

        # Work on a copy so the list held by the loaded record is not
        # mutated as a side effect of building the new history.
        history = list(conv.content or [])

        if isinstance(user_message, UserMessageSegment):
            user_msg_dict = user_message.model_dump()
        else:
            user_msg_dict = user_message
        if isinstance(assistant_message, AssistantMessageSegment):
            assistant_msg_dict = assistant_message.model_dump()
        else:
            assistant_msg_dict = assistant_message

        history.append(user_msg_dict)
        history.append(assistant_msg_dict)
        await self.db.update_conversation(
            cid=cid,
            content=history,
        )

    async def get_human_readable_context(
        self,
        unified_msg_origin: str,
        conversation_id: str,
        page: int = 1,
        page_size: int = 10,
    ) -> tuple[list[str], int]:
        """Return one page of the conversation rendered as readable lines.

        History records are grouped into turns: a turn collects the user
        lines seen so far and is closed by the next assistant record. Turns
        are listed newest first, while the lines inside a turn keep their
        chronological order. User messages that have not been answered yet
        form a final (newest) turn of their own. Records whose role is
        neither ``"user"`` nor ``"assistant"`` are not rendered.

        Args:
            unified_msg_origin: Unified message origin string, formatted as
                ``platform_name:message_type:session_id``.
            conversation_id: ID of the conversation to render.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Lines per page; values below 1 are treated as 1.

        Returns:
            A tuple of the lines on the requested page and the total number
            of pages. ``([], 0)`` when the conversation does not exist.
        """
        conversation = await self.get_conversation(unified_msg_origin, conversation_id)
        if not conversation:
            return [], 0
        history = json.loads(conversation.history)

        # contexts_groups holds the turns (each a list of lines), newest
        # first; it is flattened into a single list of lines afterwards.
        contexts_groups: list[list[str]] = []
        temp_contexts: list[str] = []
        for record in history:
            if record["role"] == "user":
                temp_contexts.append(f"User: {record['content']}")
            elif record["role"] == "assistant":
                if record.get("content"):
                    temp_contexts.append(f"Assistant: {record['content']}")
                elif "tool_calls" in record:
                    tool_calls_str = json.dumps(
                        record["tool_calls"],
                        ensure_ascii=False,
                    )
                    temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
                else:
                    temp_contexts.append("Assistant: [未知的内容]")
                # An assistant record closes the current turn.
                contexts_groups.insert(0, temp_contexts)
                temp_contexts = []
        if temp_contexts:
            # Trailing user messages without an assistant reply would
            # otherwise be silently dropped from the output.
            contexts_groups.insert(0, temp_contexts)

        # Flatten the turns into a single list of lines.
        contexts = [item for sublist in contexts_groups for item in sublist]

        # Pagination. Clamp the inputs so page_size=0 cannot raise
        # ZeroDivisionError and page<1 cannot produce a negative slice start.
        page = max(page, 1)
        page_size = max(page_size, 1)
        start = (page - 1) * page_size
        paged_contexts = contexts[start : start + page_size]
        total_pages = (len(contexts) + page_size - 1) // page_size
        return paged_contexts, total_pages