Files
AstrBot/astrbot/core/utils/session_waiter.py
2025-03-06 22:13:14 +08:00

164 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
会话控制
"""
import asyncio
import time
import functools
import copy
import astrbot.core.message.components as Comp
from typing import Dict, Any, Callable, Awaitable, List
from astrbot.core.platform import AstrMessageEvent
# Registry of currently waiting SessionWaiter instances, keyed by session_id.
USER_SESSIONS: dict[str, "SessionWaiter"] = {}
class SessionController:
"""
控制一个 Session 是否已经结束
"""
def __init__(self):
self.future = asyncio.Future()
self.current_event: asyncio.Event = None
"""当前正在等待的所用的异步事件"""
self.ts: float = None
"""上次保持(keep)开始时的时间"""
self.timeout: float | int = None
"""上次保持(keep)开始时的超时时间"""
self.history_chains: List[List[Comp.BaseMessageComponent]] = []
def stop(self, error: Exception = None):
"""立即结束这个会话"""
if not self.future.done():
if error:
self.future.set_exception(error)
else:
self.future.set_result(None)
def keep(self, timeout: float | int = 0, reset_timeout=False):
"""保持这个会话
Args:
timeout (float): 必填。会话超时时间。
当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。
当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0)
"""
new_ts = time.time()
if reset_timeout:
if timeout <= 0:
self.stop()
return
else:
left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout
if timeout <= 0:
self.stop()
return
if self.current_event and not self.current_event.is_set():
self.current_event.set() # 通知上一个 keep 结束
new_event = asyncio.Event()
self.ts = new_ts
self.current_event = new_event
self.timeout = timeout
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout: int):
"""等待事件结束或超时"""
try:
await asyncio.wait_for(event.wait(), timeout)
except asyncio.TimeoutError:
if not self.future.done():
self.future.set_exception(TimeoutError("等待超时"))
except asyncio.CancelledError:
pass # 避免报错
# finally:
def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]:
"""获取历史消息链"""
return self.history_chains
class SessionWaiter:
    """Waits for follow-up events on one session_id and feeds them to a handler."""

    def __init__(self, session_id: str, record_history_chains: bool):
        self.session_id = session_id
        # Coroutine function called as handler(session_controller, event) on each trigger.
        self.handler: (
            Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
        ) = None
        self.session_controller = SessionController()
        # Whether trigger() stores a deep copy of every incoming message chain.
        self.record_history_chains = record_history_chains
        # Ensures only one trigger() runs the handler for this session at a time.
        self._lock = asyncio.Lock()

    async def register_wait(
        self,
        handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
        timeout: int = 30,
    ) -> Any:
        """Register this waiter and block until its session ends.

        Args:
            handler: Coroutine function invoked as ``handler(controller, event)``
                for each event delivered through :meth:`trigger`.
            timeout: Initial keep-alive timeout in seconds.

        Returns:
            The session future's result (``None`` when stopped normally).

        Raises:
            TimeoutError: If the keep-alive window expires.
            Exception: Any error the session was stopped with, including
                errors raised by ``handler``.
        """
        self.handler = handler
        USER_SESSIONS[self.session_id] = self
        # Start the first keep-alive window.
        self.session_controller.keep(timeout, reset_timeout=True)
        try:
            return await self.session_controller.future
        finally:
            # If the future already failed, stop(error) would be a no-op, so a
            # plain cleanup covers both the success and the error path.
            self._cleanup()

    def _cleanup(self, error: Exception | None = None):
        """Unregister this waiter and end its session.

        The registry entry is removed only while it still points at this waiter,
        so a newer waiter registered under the same session_id is left intact.
        """
        if USER_SESSIONS.get(self.session_id) is self:
            USER_SESSIONS.pop(self.session_id, None)
        self.session_controller.stop(error)

    @classmethod
    async def trigger(cls, session_id: str, event: AstrMessageEvent):
        """Deliver an incoming event to the waiter registered for session_id.

        Does nothing if no waiter is registered or its session already ended.
        If the handler raises, the session is stopped with that exception.
        """
        session = USER_SESSIONS.get(session_id)
        if session is None or session.session_controller.future.done():
            return
        async with session._lock:
            # Re-check: the session may have ended while waiting for the lock.
            if session.session_controller.future.done():
                return
            if session.record_history_chains:
                session.session_controller.history_chains.append(
                    [copy.deepcopy(comp) for comp in event.get_messages()]
                )
            try:
                # TODO: run the handler via create_task and track it, so a
                # handler still running after a timeout can be cancelled.
                await session.handler(session.session_controller, event)
            except Exception as e:
                session.session_controller.stop(e)
def session_waiter(
    session_id_param: str, timeout: int = 30, record_history_chains: bool = False
):
    """Decorator that turns a coroutine function into a SessionWaiter handler.

    Calling the decorated function creates a new ``SessionWaiter``, registers
    the function as its handler and waits until that session ends. The function
    itself is later invoked by ``SessionWaiter.trigger`` as
    ``func(controller, event)``.

    :param session_id_param: Name of the keyword argument (passed when calling
        the decorated function) that holds the session_id.
    :param timeout: Session timeout in seconds.
    :param record_history_chains: Whether to record a deep copy of every incoming
        message chain; retrieve them with ``controller.get_history_chains()``.
    :raises ValueError: If the call does not supply a truthy session_id keyword
        argument.
    """

    def decorator(func: Callable[..., Awaitable[Any]]):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Only keyword arguments are inspected; positional args are ignored.
            session_id = kwargs.get(session_id_param)
            if not session_id:
                raise ValueError(f"缺少 session_id 参数 '{session_id_param}'")
            waiter = SessionWaiter(session_id, record_history_chains)
            return await waiter.register_wait(func, timeout)

        return wrapper

    return decorator