Files
AstrBot/astrbot/core/conversation_mgr.py

200 lines
8.2 KiB
Python

"""
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self, unified_msg_origin: str, conversation_id: str
) -> Conversation:
"""获取会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Returns:
conversation (Conversation): 对话对象
"""
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
):
"""更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history),
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin, cid=conversation_id, title=title
)
async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
)
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10
):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record:
tool_calls_str = json.dumps(
record["tool_calls"], ensure_ascii=False
)
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else:
temp_contexts.append("Assistant: [未知的内容]")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages