* Initial plan * Changes before error encountered Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
341 lines
13 KiB
Python
341 lines
13 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
|
|
from quart import Response as QuartResponse
|
|
from quart import g, make_response, request
|
|
|
|
from astrbot.core import logger
|
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|
from astrbot.core.db import BaseDatabase
|
|
from astrbot.core.platform.astr_message_event import MessageSession
|
|
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
|
|
from .route import Response, Route, RouteContext
|
|
|
|
|
|
@asynccontextmanager
|
|
async def track_conversation(convs: dict, conv_id: str):
|
|
convs[conv_id] = True
|
|
try:
|
|
yield
|
|
finally:
|
|
convs.pop(conv_id, None)
|
|
|
|
|
|
class ChatRoute(Route):
|
|
def __init__(
|
|
self,
|
|
context: RouteContext,
|
|
db: BaseDatabase,
|
|
core_lifecycle: AstrBotCoreLifecycle,
|
|
) -> None:
|
|
super().__init__(context)
|
|
self.routes = {
|
|
"/chat/send": ("POST", self.chat),
|
|
"/chat/new_conversation": ("GET", self.new_conversation),
|
|
"/chat/conversations": ("GET", self.get_conversations),
|
|
"/chat/get_conversation": ("GET", self.get_conversation),
|
|
"/chat/delete_conversation": ("GET", self.delete_conversation),
|
|
"/chat/rename_conversation": ("POST", self.rename_conversation),
|
|
"/chat/get_file": ("GET", self.get_file),
|
|
"/chat/post_image": ("POST", self.post_image),
|
|
"/chat/post_file": ("POST", self.post_file),
|
|
}
|
|
self.core_lifecycle = core_lifecycle
|
|
self.register_routes()
|
|
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
|
os.makedirs(self.imgs_dir, exist_ok=True)
|
|
|
|
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
|
self.conv_mgr = core_lifecycle.conversation_manager
|
|
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
|
|
|
self.running_convs: dict[str, bool] = {}
|
|
|
|
async def get_file(self):
|
|
filename = request.args.get("filename")
|
|
if not filename:
|
|
return Response().error("Missing key: filename").__dict__
|
|
|
|
try:
|
|
file_path = os.path.join(self.imgs_dir, os.path.basename(filename))
|
|
real_file_path = os.path.realpath(file_path)
|
|
real_imgs_dir = os.path.realpath(self.imgs_dir)
|
|
|
|
if not real_file_path.startswith(real_imgs_dir):
|
|
return Response().error("Invalid file path").__dict__
|
|
|
|
with open(real_file_path, "rb") as f:
|
|
filename_ext = os.path.splitext(filename)[1].lower()
|
|
|
|
if filename_ext == ".wav":
|
|
return QuartResponse(f.read(), mimetype="audio/wav")
|
|
if filename_ext[1:] in self.supported_imgs:
|
|
return QuartResponse(f.read(), mimetype="image/jpeg")
|
|
return QuartResponse(f.read())
|
|
|
|
except (FileNotFoundError, OSError):
|
|
return Response().error("File access error").__dict__
|
|
|
|
async def post_image(self):
|
|
post_data = await request.files
|
|
if "file" not in post_data:
|
|
return Response().error("Missing key: file").__dict__
|
|
|
|
file = post_data["file"]
|
|
filename = str(uuid.uuid4()) + ".jpg"
|
|
path = os.path.join(self.imgs_dir, filename)
|
|
await file.save(path)
|
|
|
|
return Response().ok(data={"filename": filename}).__dict__
|
|
|
|
async def post_file(self):
|
|
post_data = await request.files
|
|
if "file" not in post_data:
|
|
return Response().error("Missing key: file").__dict__
|
|
|
|
file = post_data["file"]
|
|
filename = f"{uuid.uuid4()!s}"
|
|
# 通过文件格式判断文件类型
|
|
if file.content_type.startswith("audio"):
|
|
filename += ".wav"
|
|
|
|
path = os.path.join(self.imgs_dir, filename)
|
|
await file.save(path)
|
|
|
|
return Response().ok(data={"filename": filename}).__dict__
|
|
|
|
async def chat(self):
|
|
username = g.get("username", "guest")
|
|
|
|
post_data = await request.json
|
|
if "message" not in post_data and "image_url" not in post_data:
|
|
return Response().error("Missing key: message or image_url").__dict__
|
|
|
|
if "conversation_id" not in post_data:
|
|
return Response().error("Missing key: conversation_id").__dict__
|
|
|
|
message = post_data["message"]
|
|
conversation_id = post_data["conversation_id"]
|
|
image_url = post_data.get("image_url")
|
|
audio_url = post_data.get("audio_url")
|
|
selected_provider = post_data.get("selected_provider")
|
|
selected_model = post_data.get("selected_model")
|
|
if not message and not image_url and not audio_url:
|
|
return (
|
|
Response()
|
|
.error("Message and image_url and audio_url are empty")
|
|
.__dict__
|
|
)
|
|
if not conversation_id:
|
|
return Response().error("conversation_id is empty").__dict__
|
|
|
|
# 追加用户消息
|
|
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
|
|
|
# 获取会话特定的队列
|
|
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
|
|
|
new_his = {"type": "user", "message": message}
|
|
if image_url:
|
|
new_his["image_url"] = image_url
|
|
if audio_url:
|
|
new_his["audio_url"] = audio_url
|
|
await self.platform_history_mgr.insert(
|
|
platform_id="webchat",
|
|
user_id=webchat_conv_id,
|
|
content=new_his,
|
|
sender_id=username,
|
|
sender_name=username,
|
|
)
|
|
|
|
async def stream():
|
|
client_disconnected = False
|
|
|
|
try:
|
|
async with track_conversation(self.running_convs, webchat_conv_id):
|
|
while True:
|
|
try:
|
|
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
|
except asyncio.TimeoutError:
|
|
continue
|
|
except asyncio.CancelledError:
|
|
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
|
client_disconnected = True
|
|
except Exception as e:
|
|
logger.error(f"WebChat stream error: {e}")
|
|
|
|
if not result:
|
|
continue
|
|
|
|
result_text = result["data"]
|
|
type = result.get("type")
|
|
streaming = result.get("streaming", False)
|
|
|
|
try:
|
|
if not client_disconnected:
|
|
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
|
except Exception as e:
|
|
if not client_disconnected:
|
|
logger.debug(
|
|
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}",
|
|
)
|
|
client_disconnected = True
|
|
|
|
try:
|
|
if not client_disconnected:
|
|
await asyncio.sleep(0.05)
|
|
except asyncio.CancelledError:
|
|
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
|
client_disconnected = True
|
|
|
|
if type == "end":
|
|
break
|
|
elif (
|
|
(streaming and type == "complete")
|
|
or not streaming
|
|
or type == "break"
|
|
):
|
|
# 追加机器人消息
|
|
new_his = {"type": "bot", "message": result_text}
|
|
await self.platform_history_mgr.insert(
|
|
platform_id="webchat",
|
|
user_id=webchat_conv_id,
|
|
content=new_his,
|
|
sender_id="bot",
|
|
sender_name="bot",
|
|
)
|
|
except BaseException as e:
|
|
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
|
|
|
# 将消息放入会话特定的队列
|
|
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
|
await chat_queue.put(
|
|
(
|
|
username,
|
|
webchat_conv_id,
|
|
{
|
|
"message": message,
|
|
"image_url": image_url, # list
|
|
"audio_url": audio_url,
|
|
"selected_provider": selected_provider,
|
|
"selected_model": selected_model,
|
|
},
|
|
),
|
|
)
|
|
|
|
response = await make_response(
|
|
stream(),
|
|
{
|
|
"Content-Type": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
"Transfer-Encoding": "chunked",
|
|
"Connection": "keep-alive",
|
|
},
|
|
)
|
|
response.timeout = None # fix SSE auto disconnect issue
|
|
return response
|
|
|
|
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
|
|
"""从对话 ID 中提取 WebChat 会话 ID
|
|
|
|
NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。
|
|
"""
|
|
conversation = await self.conv_mgr.get_conversation(
|
|
unified_msg_origin="webchat",
|
|
conversation_id=conversation_id,
|
|
)
|
|
if not conversation:
|
|
raise ValueError(f"Conversation with ID {conversation_id} not found.")
|
|
conv_user_id = conversation.user_id
|
|
webchat_session_id = MessageSession.from_str(conv_user_id).session_id
|
|
if "!" not in webchat_session_id:
|
|
raise ValueError(f"Invalid conv user ID: {conv_user_id}")
|
|
return webchat_session_id.split("!")[-1]
|
|
|
|
async def delete_conversation(self):
|
|
conversation_id = request.args.get("conversation_id")
|
|
if not conversation_id:
|
|
return Response().error("Missing key: conversation_id").__dict__
|
|
username = g.get("username", "guest")
|
|
|
|
# Clean up queues when deleting conversation
|
|
webchat_queue_mgr.remove_queues(conversation_id)
|
|
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
|
await self.conv_mgr.delete_conversation(
|
|
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
|
|
conversation_id=conversation_id,
|
|
)
|
|
await self.platform_history_mgr.delete(
|
|
platform_id="webchat",
|
|
user_id=webchat_conv_id,
|
|
offset_sec=99999999,
|
|
)
|
|
return Response().ok().__dict__
|
|
|
|
async def new_conversation(self):
|
|
username = g.get("username", "guest")
|
|
webchat_conv_id = str(uuid.uuid4())
|
|
conv_id = await self.conv_mgr.new_conversation(
|
|
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
|
|
platform_id="webchat",
|
|
content=[],
|
|
)
|
|
return Response().ok(data={"conversation_id": conv_id}).__dict__
|
|
|
|
async def rename_conversation(self):
|
|
post_data = await request.json
|
|
if "conversation_id" not in post_data or "title" not in post_data:
|
|
return Response().error("Missing key: conversation_id or title").__dict__
|
|
|
|
conversation_id = post_data["conversation_id"]
|
|
title = post_data["title"]
|
|
|
|
await self.conv_mgr.update_conversation(
|
|
unified_msg_origin="webchat", # fake
|
|
conversation_id=conversation_id,
|
|
title=title,
|
|
)
|
|
return Response().ok(message="重命名成功!").__dict__
|
|
|
|
async def get_conversations(self):
|
|
conversations = await self.conv_mgr.get_conversations(platform_id="webchat")
|
|
# remove content
|
|
conversations_ = []
|
|
for conv in conversations:
|
|
conv.history = None
|
|
conversations_.append(conv)
|
|
return Response().ok(data=conversations_).__dict__
|
|
|
|
async def get_conversation(self):
|
|
conversation_id = request.args.get("conversation_id")
|
|
if not conversation_id:
|
|
return Response().error("Missing key: conversation_id").__dict__
|
|
|
|
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
|
|
|
# Get platform message history
|
|
history_ls = await self.platform_history_mgr.get(
|
|
platform_id="webchat",
|
|
user_id=webchat_conv_id,
|
|
page=1,
|
|
page_size=1000,
|
|
)
|
|
|
|
history_res = [history.model_dump() for history in history_ls]
|
|
|
|
return (
|
|
Response()
|
|
.ok(
|
|
data={
|
|
"history": history_res,
|
|
"is_running": self.running_convs.get(webchat_conv_id, False),
|
|
},
|
|
)
|
|
.__dict__
|
|
)
|