842c3c8ea9
* stage * stage * refactor: using sqlchemy as ORM framework, switch to async-based sqlite operation - using sqlmodel as ORM(based on sqlchemy and pydantic) - add Persona, Preference, PlatformMessageHistory table * fix: conversation * fix: remove redundant explicit session.commit, and fix some type error * fix: conversation context issue * chore: remove comments * chore: remove exclude_content param
229 lines
8.7 KiB
Python
229 lines
8.7 KiB
Python
import traceback
|
|
import json
|
|
from .route import Route, Response, RouteContext
|
|
from astrbot.core import logger
|
|
from quart import request
|
|
from astrbot.core.db import BaseDatabase
|
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|
|
|
|
|
class ConversationRoute(Route):
|
|
def __init__(
|
|
self,
|
|
context: RouteContext,
|
|
db_helper: BaseDatabase,
|
|
core_lifecycle: AstrBotCoreLifecycle,
|
|
) -> None:
|
|
super().__init__(context)
|
|
self.routes = {
|
|
"/conversation/list": ("GET", self.list_conversations),
|
|
"/conversation/detail": (
|
|
"POST",
|
|
self.get_conv_detail,
|
|
),
|
|
"/conversation/update": ("POST", self.upd_conv),
|
|
"/conversation/delete": ("POST", self.del_conv),
|
|
"/conversation/update_history": (
|
|
"POST",
|
|
self.update_history,
|
|
),
|
|
}
|
|
self.db_helper = db_helper
|
|
self.conv_mgr = core_lifecycle.conversation_manager
|
|
self.core_lifecycle = core_lifecycle
|
|
self.register_routes()
|
|
|
|
async def list_conversations(self):
|
|
"""获取对话列表,支持分页、排序和筛选"""
|
|
try:
|
|
# 获取分页参数
|
|
page = request.args.get("page", 1, type=int)
|
|
page_size = request.args.get("page_size", 20, type=int)
|
|
|
|
# 获取筛选参数
|
|
platforms = request.args.get("platforms", "")
|
|
message_types = request.args.get("message_types", "")
|
|
search_query = request.args.get("search", "")
|
|
exclude_ids = request.args.get("exclude_ids", "")
|
|
exclude_platforms = request.args.get("exclude_platforms", "")
|
|
|
|
# 转换为列表
|
|
platform_list = platforms.split(",") if platforms else []
|
|
message_type_list = message_types.split(",") if message_types else []
|
|
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
|
|
exclude_platform_list = (
|
|
exclude_platforms.split(",") if exclude_platforms else []
|
|
)
|
|
|
|
if page < 1:
|
|
page = 1
|
|
if page_size < 1:
|
|
page_size = 20
|
|
if page_size > 100:
|
|
page_size = 100
|
|
|
|
try:
|
|
(
|
|
conversations,
|
|
total_count,
|
|
) = await self.conv_mgr.get_filtered_conversations(
|
|
page=page,
|
|
page_size=page_size,
|
|
platforms=platform_list,
|
|
message_types=message_type_list,
|
|
search_query=search_query,
|
|
exclude_ids=exclude_id_list,
|
|
exclude_platforms=exclude_platform_list,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"数据库查询出错: {str(e)}\n{traceback.format_exc()}")
|
|
return Response().error(f"数据库查询出错: {str(e)}").__dict__
|
|
|
|
# 计算总页数
|
|
total_pages = (
|
|
(total_count + page_size - 1) // page_size if total_count > 0 else 1
|
|
)
|
|
|
|
result = {
|
|
"conversations": conversations,
|
|
"pagination": {
|
|
"page": page,
|
|
"page_size": page_size,
|
|
"total": total_count,
|
|
"total_pages": total_pages,
|
|
},
|
|
}
|
|
return Response().ok(result).__dict__
|
|
|
|
except Exception as e:
|
|
error_msg = f"获取对话列表失败: {str(e)}\n{traceback.format_exc()}"
|
|
logger.error(error_msg)
|
|
return Response().error(f"获取对话列表失败: {str(e)}").__dict__
|
|
|
|
async def get_conv_detail(self):
|
|
"""获取指定对话详情(通过POST请求)"""
|
|
try:
|
|
data = await request.get_json()
|
|
user_id = data.get("user_id")
|
|
cid = data.get("cid")
|
|
|
|
if not user_id or not cid:
|
|
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
|
|
|
conversation = await self.conv_mgr.get_conversation(
|
|
unified_msg_origin=user_id, conversation_id=cid
|
|
)
|
|
if not conversation:
|
|
return Response().error("对话不存在").__dict__
|
|
|
|
return (
|
|
Response()
|
|
.ok(
|
|
{
|
|
"user_id": user_id,
|
|
"cid": cid,
|
|
"title": conversation.title,
|
|
"persona_id": conversation.persona_id,
|
|
"history": conversation.history,
|
|
"created_at": conversation.created_at,
|
|
"updated_at": conversation.updated_at,
|
|
}
|
|
)
|
|
.__dict__
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"获取对话详情失败: {str(e)}\n{traceback.format_exc()}")
|
|
return Response().error(f"获取对话详情失败: {str(e)}").__dict__
|
|
|
|
async def upd_conv(self):
|
|
"""更新对话信息(标题和角色ID)"""
|
|
try:
|
|
data = await request.get_json()
|
|
user_id = data.get("user_id")
|
|
cid = data.get("cid")
|
|
title = data.get("title")
|
|
persona_id = data.get("persona_id", "")
|
|
|
|
if not user_id or not cid:
|
|
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
|
conversation = await self.conv_mgr.get_conversation(
|
|
unified_msg_origin=user_id, conversation_id=cid
|
|
)
|
|
if not conversation:
|
|
return Response().error("对话不存在").__dict__
|
|
if title is not None or persona_id is not None:
|
|
await self.conv_mgr.update_conversation(
|
|
unified_msg_origin=user_id,
|
|
conversation_id=cid,
|
|
title=title,
|
|
persona_id=persona_id,
|
|
)
|
|
return Response().ok({"message": "对话信息更新成功"}).__dict__
|
|
|
|
except Exception as e:
|
|
logger.error(f"更新对话信息失败: {str(e)}\n{traceback.format_exc()}")
|
|
return Response().error(f"更新对话信息失败: {str(e)}").__dict__
|
|
|
|
async def del_conv(self):
|
|
"""删除对话"""
|
|
try:
|
|
data = await request.get_json()
|
|
user_id = data.get("user_id")
|
|
cid = data.get("cid")
|
|
|
|
if not user_id or not cid:
|
|
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
|
await self.core_lifecycle.conversation_manager.delete_conversation(
|
|
unified_msg_origin=user_id, conversation_id=cid
|
|
)
|
|
return Response().ok({"message": "对话删除成功"}).__dict__
|
|
|
|
except Exception as e:
|
|
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
|
|
return Response().error(f"删除对话失败: {str(e)}").__dict__
|
|
|
|
async def update_history(self):
|
|
"""更新对话历史内容"""
|
|
try:
|
|
data = await request.get_json()
|
|
user_id = data.get("user_id")
|
|
cid = data.get("cid")
|
|
history = data.get("history")
|
|
|
|
if not user_id or not cid:
|
|
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
|
|
|
if history is None:
|
|
return Response().error("缺少必要参数: history").__dict__
|
|
|
|
# 历史记录必须是合法的 JSON 字符串
|
|
try:
|
|
if isinstance(history, list):
|
|
history = json.dumps(history)
|
|
else:
|
|
# 验证是否为有效的 JSON 字符串
|
|
json.loads(history)
|
|
except json.JSONDecodeError:
|
|
return (
|
|
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
|
|
)
|
|
|
|
conversation = await self.conv_mgr.get_conversation(
|
|
unified_msg_origin=user_id, conversation_id=cid
|
|
)
|
|
if not conversation:
|
|
return Response().error("对话不存在").__dict__
|
|
|
|
history = json.loads(history) if isinstance(history, str) else history
|
|
|
|
await self.conv_mgr.update_conversation(
|
|
unified_msg_origin=user_id, conversation_id=cid, history=history
|
|
)
|
|
|
|
return Response().ok({"message": "对话历史更新成功"}).__dict__
|
|
|
|
except Exception as e:
|
|
logger.error(f"更新对话历史失败: {str(e)}\n{traceback.format_exc()}")
|
|
return Response().error(f"更新对话历史失败: {str(e)}").__dict__
|