# File: AstrBot/packages/astrbot/commands/conversation.py (382 lines, 14 KiB, Python)
#
# Web-viewer notice: this file contains ambiguous Unicode characters, i.e.
# characters that might be confused with other characters. If this is
# intentional, the warning can safely be ignored.

import datetime
from astrbot.api import logger, sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.message_type import MessageType
from ..long_term_memory import LongTermMemory
from .utils.rst_scene import RstScene
# Maps each third-party agent runner type to the preference key under which
# its conversation id is stored for a session.
THIRD_PARTY_AGENT_RUNNER_KEY: dict[str, str] = {
    "dify": "dify_conversation_id",
    "coze": "coze_conversation_id",
}

# Comma-separated runner names, used in user-facing messages.
THIRD_PARTY_AGENT_RUNNER_STR: str = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY)
class ConversationCommands:
    """Chat-command handlers that manage a session's LLM conversations.

    Every public handler receives the triggering ``AstrMessageEvent`` and
    reports its outcome by calling ``set_result`` on that event; none of them
    returns a value.
    """

    def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
        """Store the plugin context and the optional long-term-memory helper."""
        self.context = context
        self.ltm = ltm

    async def _get_current_persona_id(self, session_id):
        """Return the persona id of the session's current conversation.

        Returns ``None`` when the session has no current conversation, or when
        the current conversation cannot be loaded.
        """
        conv_mgr = self.context.conversation_manager
        curr = await conv_mgr.get_curr_conversation_id(session_id)
        if not curr:
            return None
        conv = await conv_mgr.get_conversation(session_id, curr)
        if conv is None:
            # NOTE(review): presumably get_conversation yields None for a stale
            # id; guard so callers get "no persona" instead of AttributeError.
            return None
        return conv.persona_id

    async def _clear_third_party_conversation(self, umo, agent_runner_type) -> bool:
        """Drop the stored third-party conversation id for ``umo``, if applicable.

        Returns ``True`` when ``agent_runner_type`` is one of the runners in
        ``THIRD_PARTY_AGENT_RUNNER_KEY`` (its stored key is then removed) and
        ``False`` otherwise, in which case nothing is touched.
        """
        if agent_runner_type not in THIRD_PARTY_AGENT_RUNNER_KEY:
            return False
        await sp.remove_async(
            scope="umo",
            scope_id=umo,
            key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
        )
        return True

    def ltm_enabled(self, event: AstrMessageEvent):
        """Tell whether long-term memory is active for the event's session.

        Returns ``False`` when no LTM helper was supplied. Otherwise returns
        ``group_icl_enable or active_reply.enable`` as read from the session's
        ``provider_ltm_settings`` configuration.
        """
        if not self.ltm:
            return False
        ltmse = self.context.get_config(umo=event.unified_msg_origin)[
            "provider_ltm_settings"
        ]
        return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]

    async def reset(self, message: AstrMessageEvent):
        """Reset the LLM conversation of the current session."""
        umo = message.unified_msg_origin
        cfg = self.context.get_config(umo=umo)
        is_unique_session = cfg["platform_settings"]["unique_session"]
        is_group = bool(message.get_group_id())
        scene = RstScene.get_scene(is_group, is_unique_session)

        # The required permission may be overridden per scene in the stored
        # "alter_cmd" settings; by default only shared group sessions need admin.
        alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
        plugin_config = alter_cmd_cfg.get("astrbot", {})
        reset_cfg = plugin_config.get("reset", {})
        required_perm = reset_cfg.get(
            scene.key,
            "admin" if is_group and not is_unique_session else "member",
        )
        if required_perm == "admin" and message.role != "admin":
            message.set_result(
                MessageEventResult().message(
                    # The trailing comma separates the two clauses, which
                    # would otherwise run together when concatenated.
                    f"{scene.name}场景下reset命令需要管理员权限,"
                    f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。",
                ),
            )
            return

        # Third-party runners keep their own history; forgetting the stored
        # conversation id is all that is done for them here.
        agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
        if await self._clear_third_party_conversation(umo, agent_runner_type):
            message.set_result(MessageEventResult().message("重置对话成功。"))
            return

        if not self.context.get_using_provider(umo):
            message.set_result(
                MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
            )
            return

        cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
        if not cid:
            message.set_result(
                MessageEventResult().message(
                    "当前未处于对话状态,请 /switch 切换或者 /new 创建。",
                ),
            )
            return

        # Replace the stored history of the current conversation with an empty list.
        await self.context.conversation_manager.update_conversation(
            umo,
            cid,
            [],
        )
        ret = "清除会话 LLM 聊天历史成功。"
        if self.ltm and self.ltm_enabled(message):
            cnt = await self.ltm.remove_session(event=message)
            ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
        message.set_result(MessageEventResult().message(ret))

    async def his(self, message: AstrMessageEvent, page: int = 1):
        """Show one page of the current conversation's history."""
        umo = message.unified_msg_origin
        if not self.context.get_using_provider(umo):
            message.set_result(
                MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
            )
            return

        size_per_page = 6
        max_entry_len = 150
        conv_mgr = self.context.conversation_manager

        session_curr_cid = await conv_mgr.get_curr_conversation_id(umo)
        if not session_curr_cid:
            # No current conversation yet: create one so there is something to show.
            session_curr_cid = await conv_mgr.new_conversation(
                umo,
                message.get_platform_id(),
            )

        contexts, total_pages = await conv_mgr.get_human_readable_context(
            umo,
            session_curr_cid,
            page,
            size_per_page,
        )

        # Truncate long entries so a single message cannot flood the reply.
        entries = [
            entry if len(entry) <= max_entry_len else entry[:max_entry_len] + "..."
            for entry in contexts
        ]
        history = "".join(f"{entry}\n" for entry in entries)

        ret = (
            "当前对话历史记录:"
            f"{history or '无历史记录'}\n\n"
            f"{page} 页 | 共 {total_pages}\n"
            "*输入 /history 2 跳转到第 2 页"
        )
        message.set_result(MessageEventResult().message(ret).use_t2i(False))

    async def convs(self, message: AstrMessageEvent, page: int = 1):
        """Show one page of the session's conversation list."""
        umo = message.unified_msg_origin
        cfg = self.context.get_config(umo=umo)
        agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
        if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
            message.set_result(
                MessageEventResult().message(
                    f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
                ),
            )
            return

        size_per_page = 6
        conv_mgr = self.context.conversation_manager

        # Fetch every conversation, then clamp the requested page into range.
        conversations_all = await conv_mgr.get_conversations(umo)
        total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page
        page = max(1, min(page, total_pages))

        start_idx = (page - 1) * size_per_page
        conversations_paged = conversations_all[start_idx : start_idx + size_per_page]

        # Titles of all conversations (not only this page), because the
        # current conversation may live on another page.
        _titles = {conv.cid: conv.title or "新对话" for conv in conversations_all}

        parts = ["对话列表:\n---\n"]
        default_persona_name = None  # resolved lazily, at most once per call
        # Numbering is global: it continues from the first entry of this page.
        for global_index, conv in enumerate(conversations_paged, start=start_idx + 1):
            persona_id = conv.persona_id
            if not persona_id or persona_id == "[%None]":
                if default_persona_name is None:
                    persona = await self.context.persona_manager.get_default_persona_v3(
                        umo=umo,
                    )
                    default_persona_name = persona["name"]
                persona_id = default_persona_name
            title = _titles.get(conv.cid, "新对话")
            updated = datetime.datetime.fromtimestamp(conv.updated_at).strftime(
                "%m-%d %H:%M"
            )
            parts.append(
                f"{global_index}. {title}({conv.cid[:4]})\n"
                f" 人格情景: {persona_id}\n"
                f" 上次更新: {updated}\n"
            )
        parts.append("---\n")
        ret = "".join(parts)

        curr_cid = await conv_mgr.get_curr_conversation_id(umo)
        if curr_cid:
            title = _titles.get(curr_cid, "新对话")
            ret += f"\n当前对话: {title}({curr_cid[:4]})"
        else:
            ret += "\n当前对话: 无"

        if cfg["platform_settings"]["unique_session"]:
            ret += "\n会话隔离粒度: 个人"
        else:
            ret += "\n会话隔离粒度: 群聊"

        ret += f"\n{page} 页 | 共 {total_pages}"
        ret += "\n*输入 /ls 2 跳转到第 2 页"

        message.set_result(MessageEventResult().message(ret).use_t2i(False))

    async def new_conv(self, message: AstrMessageEvent):
        """Create a new conversation and make it the current one."""
        umo = message.unified_msg_origin
        cfg = self.context.get_config(umo=umo)
        agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
        if await self._clear_third_party_conversation(umo, agent_runner_type):
            message.set_result(MessageEventResult().message("已创建新对话。"))
            return

        # The new conversation inherits the persona of the current one.
        cpersona = await self._get_current_persona_id(umo)
        cid = await self.context.conversation_manager.new_conversation(
            umo,
            message.get_platform_id(),
            persona_id=cpersona,
        )

        # Long-term memory cleanup is best-effort: a failure is logged and the
        # new conversation is still reported.
        if self.ltm and self.ltm_enabled(message):
            try:
                await self.ltm.remove_session(event=message)
            except Exception as e:
                logger.error(f"清理聊天增强记录失败: {e}")

        message.set_result(
            MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
        )

    async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
        """Create a new conversation for the group chat identified by ``sid``."""
        if not sid:
            message.set_result(
                MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。"),
            )
            return

        # Build the group session's identifier on the sender's platform.
        session = str(
            MessageSession(
                platform_name=message.platform_meta.id,
                message_type=MessageType("GroupMessage"),
                session_id=sid,
            ),
        )
        cpersona = await self._get_current_persona_id(session)
        cid = await self.context.conversation_manager.new_conversation(
            session,
            message.get_platform_id(),
            persona_id=cpersona,
        )
        message.set_result(
            MessageEventResult().message(
                f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。",
            ),
        )

    async def switch_conv(
        self,
        message: AstrMessageEvent,
        index: int | None = None,
    ):
        """Switch to the conversation with the given 1-based /ls index."""
        # A missing index must be handled before the type check: None is not an
        # int, so checking the type first would hide the usage hint.
        if index is None:
            message.set_result(
                MessageEventResult().message(
                    "请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话",
                ),
            )
            return

        if not isinstance(index, int):
            message.set_result(
                MessageEventResult().message("类型错误,请输入数字对话序号。"),
            )
            return

        conversations = await self.context.conversation_manager.get_conversations(
            message.unified_msg_origin,
        )
        if index > len(conversations) or index < 1:
            message.set_result(
                MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
            )
            return

        conversation = conversations[index - 1]
        title = conversation.title if conversation.title else "新对话"
        await self.context.conversation_manager.switch_conversation(
            message.unified_msg_origin,
            conversation.cid,
        )
        message.set_result(
            MessageEventResult().message(
                f"切换到对话: {title}({conversation.cid[:4]})。",
            ),
        )

    async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""):
        """Rename the session's conversation to ``new_name``."""
        if not new_name:
            message.set_result(MessageEventResult().message("请输入新的对话名称。"))
            return
        await self.context.conversation_manager.update_conversation_title(
            message.unified_msg_origin,
            new_name,
        )
        message.set_result(MessageEventResult().message("重命名对话成功。"))

    async def del_conv(self, message: AstrMessageEvent):
        """Delete the session's current conversation."""
        umo = message.unified_msg_origin
        cfg = self.context.get_config(umo=umo)
        is_unique_session = cfg["platform_settings"]["unique_session"]
        if message.get_group_id() and not is_unique_session and message.role != "admin":
            # Group chat without per-user sessions: only admins may delete.
            message.set_result(
                MessageEventResult().message(
                    f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。",
                ),
            )
            return

        agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
        if await self._clear_third_party_conversation(umo, agent_runner_type):
            message.set_result(MessageEventResult().message("重置对话成功。"))
            return

        conv_mgr = self.context.conversation_manager
        session_curr_cid = await conv_mgr.get_curr_conversation_id(umo)
        if not session_curr_cid:
            message.set_result(
                MessageEventResult().message(
                    "当前未处于对话状态,请 /switch 序号 切换或 /new 创建。",
                ),
            )
            return

        await conv_mgr.delete_conversation(umo, session_curr_cid)

        ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
        if self.ltm and self.ltm_enabled(message):
            cnt = await self.ltm.remove_session(event=message)
            ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
        message.set_result(MessageEventResult().message(ret))