* refactor: llm message schema * feat: implement MCPTool and local LLM tools with enhanced context handling * refactor: reorganize imports and enhance docstrings for clarity * refactor: enhance ContentPart validation and add message pair handling in ConversationManager * chore: ruff format * refactor: remove debug print statement from payloads in ProviderOpenAIOfficial * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/pipeline/process_stage/method/llm_request.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: enhance documentation and import mcp in tool.py; update call method return type * fix: 修复以数据类的方式注册 tool 时的插件重载机制问题 * refactor: change role attributes to use Literal types for message segments * fix: add support for 'decorator_handler' method in call_local_llm_tool * fix: handle None prompt in text_chat method and ensure context is properly formatted --------- Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
410 lines
15 KiB
Python
410 lines
15 KiB
Python
"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库.
|
||
|
||
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
|
||
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||
"""
|
||
|
||
import json
|
||
from collections.abc import Awaitable, Callable
|
||
|
||
from astrbot.core import sp
|
||
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
|
||
from astrbot.core.db import BaseDatabase
|
||
from astrbot.core.db.po import Conversation, ConversationV2
|
||
|
||
|
||
class ConversationManager:
    """Manage LLM conversations and track which one each session is currently using.

    The selected conversation of each session (keyed by ``unified_msg_origin``) is
    cached in memory in ``session_conversations`` and persisted through ``sp`` under
    the key ``"sel_conv_id"``. Conversation records themselves are stored via
    ``self.db``.
    """

    def __init__(self, db_helper: BaseDatabase):
        # In-memory cache: unified_msg_origin -> currently selected conversation ID.
        self.session_conversations: dict[str, str] = {}
        self.db = db_helper
        # Save every 60 seconds (this value is not read anywhere in this class).
        self.save_interval = 60

        # Callbacks run when a session is deleted, for cascading cleanup
        # (e.g. a knowledge-base module clearing its per-session configuration).
        self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []

    def register_on_session_deleted(
        self,
        callback: Callable[[str], Awaitable[None]],
    ) -> None:
        """Register a callback to be awaited when a session is deleted.

        Other modules can register callbacks to react to session deletion and
        perform cascading cleanup, e.g. a knowledge-base module removing the
        session's knowledge-base configuration.

        Args:
            callback: Async callable receiving the session ID (unified_msg_origin).

        """
        self._on_session_deleted_callbacks.append(callback)

    async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
        """Await every registered session-deleted callback in registration order.

        A failing callback is logged and does not prevent the remaining
        callbacks from running.

        Args:
            unified_msg_origin: Session ID.

        """
        for callback in self._on_session_deleted_callbacks:
            try:
                await callback(unified_msg_origin)
            except Exception as e:
                # Imported lazily; presumably to avoid a circular import at module load.
                from astrbot.core import logger

                logger.error(
                    f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
                )

    def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
        """Convert a ConversationV2 record into a legacy Conversation object.

        Timestamps become integer epoch seconds and the content list is
        serialized to a JSON string (an empty list when content is falsy).
        """
        created_at = int(conv_v2.created_at.timestamp())
        updated_at = int(conv_v2.updated_at.timestamp())
        return Conversation(
            platform_id=conv_v2.platform_id,
            user_id=conv_v2.user_id,
            cid=conv_v2.conversation_id,
            history=json.dumps(conv_v2.content or []),
            title=conv_v2.title,
            persona_id=conv_v2.persona_id,
            created_at=created_at,
            updated_at=updated_at,
        )

    async def new_conversation(
        self,
        unified_msg_origin: str,
        platform_id: str | None = None,
        content: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
    ) -> str:
        """Create a new conversation and make it the session's selected conversation.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            platform_id (str | None): Platform ID. When omitted, the first segment
                of unified_msg_origin is used if it has at least three segments,
                otherwise "unknown".
            content (list[dict] | None): Initial message history.
            title (str | None): Conversation title.
            persona_id (str | None): Persona ID bound to the conversation.

        Returns:
            conversation_id (str): ID of the new conversation (a UUID string).

        """
        if not platform_id:
            # Derive platform_id from unified_msg_origin when not given.
            parts = unified_msg_origin.split(":")
            if len(parts) >= 3:
                platform_id = parts[0]
        if not platform_id:
            platform_id = "unknown"
        conv = await self.db.create_conversation(
            user_id=unified_msg_origin,
            platform_id=platform_id,
            content=content,
            title=title,
            persona_id=persona_id,
        )
        self.session_conversations[unified_msg_origin] = conv.conversation_id
        await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
        return conv.conversation_id

    async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
        """Select a conversation for the session (in memory and persisted via sp).

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            conversation_id (str): Conversation ID (a UUID string).

        """
        self.session_conversations[unified_msg_origin] = conversation_id
        await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)

    async def delete_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str | None = None,
    ):
        """Delete a conversation; when conversation_id is None, delete the session's current one.

        If the deleted conversation was the session's selected conversation, the
        selection is cleared both in memory and in sp.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            conversation_id (str | None): Conversation ID (a UUID string).

        """
        if not conversation_id:
            # Resolve through get_curr_conversation_id so the persisted selection is
            # honored even when the in-memory cache is cold (e.g. after a restart).
            conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
        if not conversation_id:
            return
        await self.db.delete_conversation(cid=conversation_id)
        curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
        if curr_cid == conversation_id:
            self.session_conversations.pop(unified_msg_origin, None)
            await sp.session_remove(unified_msg_origin, "sel_conv_id")

    async def delete_conversations_by_user_id(self, unified_msg_origin: str):
        """Delete all conversations of a session, clear its selection and run cleanup callbacks.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.

        """
        await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
        self.session_conversations.pop(unified_msg_origin, None)
        await sp.session_remove(unified_msg_origin, "sel_conv_id")

        # Cascading cleanup by registered session-deleted callbacks.
        await self._trigger_session_deleted(unified_msg_origin)

    async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
        """Return the session's selected conversation ID, or None if none is selected.

        Checks the in-memory cache first, then falls back to sp and caches a hit.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.

        Returns:
            conversation_id (str | None): Conversation ID (a UUID string).

        """
        ret = self.session_conversations.get(unified_msg_origin, None)
        if not ret:
            ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
            if ret:
                self.session_conversations[unified_msg_origin] = ret
        return ret

    async def get_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str,
        create_if_not_exists: bool = False,
    ) -> Conversation | None:
        """Fetch a conversation by ID, optionally creating a new one if it is missing.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            conversation_id (str): Conversation ID (a UUID string).
            create_if_not_exists (bool): When the conversation does not exist,
                create a new one for the session (with a fresh ID) and return it.

        Returns:
            conversation (Conversation | None): The conversation, or None if not found.

        """
        conv = await self.db.get_conversation_by_id(cid=conversation_id)
        if not conv and create_if_not_exists:
            # Not found and creation requested: create a new conversation.
            conversation_id = await self.new_conversation(unified_msg_origin)
            conv = await self.db.get_conversation_by_id(cid=conversation_id)
        if not conv:
            return None
        return self._convert_conv_from_v2_to_v1(conv)

    async def get_conversations(
        self,
        unified_msg_origin: str | None = None,
        platform_id: str | None = None,
    ) -> list[Conversation]:
        """List conversations, optionally filtered by session and/or platform.

        Args:
            unified_msg_origin (str | None): Unified message origin, formatted as
                platform_name:message_type:session_id. Optional.
            platform_id (str | None): Platform ID used to filter conversations. Optional.

        Returns:
            conversations (list[Conversation]): Matching conversations.

        """
        convs = await self.db.get_conversations(
            user_id=unified_msg_origin,
            platform_id=platform_id,
        )
        return [self._convert_conv_from_v2_to_v1(conv) for conv in convs]

    async def get_filtered_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
        platform_ids: list[str] | None = None,
        search_query: str = "",
        **kwargs,
    ) -> tuple[list[Conversation], int]:
        """Return one page of filtered conversations plus the count reported by the database.

        Args:
            page (int): Page number, defaults to 1.
            page_size (int): Page size, defaults to 20.
            platform_ids (list[str] | None): Platform IDs to filter by. Optional.
            search_query (str): Search string. Optional.
            **kwargs: Extra filters forwarded unchanged to
                ``db.get_filtered_conversations``.

        Returns:
            tuple[list[Conversation], int]: The conversations on this page and the
                count returned by ``db.get_filtered_conversations`` (presumably the
                total number of matches — confirm against the DB implementation).

        """
        convs, cnt = await self.db.get_filtered_conversations(
            page=page,
            page_size=page_size,
            platform_ids=platform_ids,
            search_query=search_query,
            **kwargs,
        )
        return [self._convert_conv_from_v2_to_v1(conv) for conv in convs], cnt

    async def update_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str | None = None,
        history: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
    ) -> None:
        """Update a conversation's history, title and/or persona ID.

        Does nothing if no conversation ID is given and the session has no
        selected conversation.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            conversation_id (str | None): Conversation ID (a UUID string); defaults
                to the session's selected conversation.
            history (list[dict] | None): Message history; each dict has role and
                content fields.
            title (str | None): New title.
            persona_id (str | None): New persona ID.

        """
        if not conversation_id:
            # Fall back to the session's currently selected conversation.
            conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
        if conversation_id:
            await self.db.update_conversation(
                cid=conversation_id,
                title=title,
                persona_id=persona_id,
                content=history,
            )

    async def update_conversation_title(
        self,
        unified_msg_origin: str,
        title: str,
        conversation_id: str | None = None,
    ) -> None:
        """Update a conversation's title.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            title (str): Conversation title.
            conversation_id (str | None): Conversation ID (a UUID string).

        Deprecated:
            Use `update_conversation` with `title` parameter instead.

        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,
            conversation_id=conversation_id,
            title=title,
        )

    async def update_conversation_persona_id(
        self,
        unified_msg_origin: str,
        persona_id: str,
        conversation_id: str | None = None,
    ) -> None:
        """Update a conversation's persona ID.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            persona_id (str): Persona ID.
            conversation_id (str | None): Conversation ID (a UUID string).

        Deprecated:
            Use `update_conversation` with `persona_id` parameter instead.

        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,
            conversation_id=conversation_id,
            persona_id=persona_id,
        )

    async def add_message_pair(
        self,
        cid: str,
        user_message: UserMessageSegment | dict,
        assistant_message: AssistantMessageSegment | dict,
    ) -> None:
        """Add a user-assistant message pair to the conversation history.

        Args:
            cid (str): Conversation ID
            user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
            assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict

        Raises:
            Exception: If the conversation with the given ID is not found

        """
        conv = await self.db.get_conversation_by_id(cid=cid)
        if not conv:
            raise Exception(f"Conversation with id {cid} not found")

        # Copy so the loaded record's content list is not mutated in place.
        history = list(conv.content or [])
        user_msg_dict = (
            user_message.model_dump()
            if isinstance(user_message, UserMessageSegment)
            else user_message
        )
        assistant_msg_dict = (
            assistant_message.model_dump()
            if isinstance(assistant_message, AssistantMessageSegment)
            else assistant_message
        )
        history.append(user_msg_dict)
        history.append(assistant_msg_dict)
        await self.db.update_conversation(
            cid=cid,
            content=history,
        )

    async def get_human_readable_context(
        self,
        unified_msg_origin: str,
        conversation_id: str,
        page: int = 1,
        page_size: int = 10,
    ) -> tuple[list[str], int]:
        """Render a conversation's history as readable lines, paginated.

        Records are grouped into rounds, where each round ends with an assistant
        message. Rounds are ordered newest first, and lines within a round stay in
        chronological order. Records with roles other than user/assistant (e.g.
        tool results) are skipped.

        Args:
            unified_msg_origin (str): Unified message origin, formatted as
                platform_name:message_type:session_id.
            conversation_id (str): Conversation ID (a UUID string).
            page (int): 1-based page number.
            page_size (int): Number of lines per page.

        Returns:
            tuple[list[str], int]: Lines on the requested page and the total page
                count; ([], 0) when the conversation does not exist.

        """
        conversation = await self.get_conversation(unified_msg_origin, conversation_id)
        if not conversation:
            return [], 0
        history = json.loads(conversation.history)

        # Each group holds the lines of one round, in chronological order.
        groups: list[list[str]] = []
        current: list[str] = []
        for record in history:
            role = record.get("role")
            if role == "user":
                current.append(f"User: {record['content']}")
            elif role == "assistant":
                if record.get("content"):
                    current.append(f"Assistant: {record['content']}")
                elif "tool_calls" in record:
                    tool_calls_str = json.dumps(
                        record["tool_calls"],
                        ensure_ascii=False,
                    )
                    current.append(f"Assistant: [函数调用] {tool_calls_str}")
                else:
                    current.append("Assistant: [未知的内容]")
                groups.append(current)
                current = []
        # A trailing user message with no assistant reply yet must not be dropped.
        if current:
            groups.append(current)

        # Flatten newest round first (append + reversed avoids O(n^2) list.insert(0, ...)).
        contexts = [line for group in reversed(groups) for line in group]

        # Pagination.
        paged_contexts = contexts[(page - 1) * page_size : page * page_size]
        total_pages = len(contexts) // page_size
        if len(contexts) % page_size != 0:
            total_pages += 1

        return paged_contexts, total_pages
|