✨ feat: 支持插件会话控制 API
This commit is contained in:
7
astrbot/api/util/__init__.py
Normal file
7
astrbot/api/util/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from astrbot.core.utils.session_waiter import (
|
||||
SessionWaiter,
|
||||
SessionController,
|
||||
session_waiter,
|
||||
)
|
||||
|
||||
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]
|
||||
163
astrbot/core/utils/session_waiter.py
Normal file
163
astrbot/core/utils/session_waiter.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
会话控制
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import functools
|
||||
import copy
|
||||
import astrbot.core.message.components as Comp
|
||||
from typing import Dict, Any, Callable, Awaitable, List
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
|
||||
|
||||
|
||||
class SessionController:
|
||||
"""
|
||||
控制一个 Session 是否已经结束
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.future = asyncio.Future()
|
||||
self.current_event: asyncio.Event = None
|
||||
"""当前正在等待的所用的异步事件"""
|
||||
self.ts: float = None
|
||||
"""上次保持(keep)开始时的时间"""
|
||||
self.timeout: float | int = None
|
||||
"""上次保持(keep)开始时的超时时间"""
|
||||
|
||||
self.history_chains: List[List[Comp.BaseMessageComponent]] = []
|
||||
|
||||
def stop(self, error: Exception = None):
|
||||
"""立即结束这个会话"""
|
||||
if not self.future.done():
|
||||
if error:
|
||||
self.future.set_exception(error)
|
||||
else:
|
||||
self.future.set_result(None)
|
||||
|
||||
def keep(self, timeout: float | int = 0, reset_timeout=False):
|
||||
"""保持这个会话
|
||||
|
||||
Args:
|
||||
timeout (float): 必填。会话超时时间。
|
||||
当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。
|
||||
当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0)
|
||||
"""
|
||||
new_ts = time.time()
|
||||
|
||||
if reset_timeout:
|
||||
if timeout <= 0:
|
||||
self.stop()
|
||||
return
|
||||
else:
|
||||
left_timeout = self.timeout - (new_ts - self.ts)
|
||||
timeout = left_timeout + timeout
|
||||
if timeout <= 0:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
if self.current_event and not self.current_event.is_set():
|
||||
self.current_event.set() # 通知上一个 keep 结束
|
||||
|
||||
new_event = asyncio.Event()
|
||||
self.ts = new_ts
|
||||
self.current_event = new_event
|
||||
self.timeout = timeout
|
||||
|
||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
|
||||
async def _holding(self, event: asyncio.Event, timeout: int):
|
||||
"""等待事件结束或超时"""
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if not self.future.done():
|
||||
self.future.set_exception(TimeoutError("等待超时"))
|
||||
except asyncio.CancelledError:
|
||||
pass # 避免报错
|
||||
# finally:
|
||||
|
||||
def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]:
|
||||
"""获取历史消息链"""
|
||||
return self.history_chains
|
||||
|
||||
|
||||
class SessionWaiter:
|
||||
def __init__(self, session_id: str, record_history_chains: bool):
|
||||
self.session_id = session_id
|
||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
||||
|
||||
self.session_controller = SessionController()
|
||||
self.record_history_chains = record_history_chains
|
||||
"""是否记录历史消息链"""
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
"""需要保证一个 session 同时只有一个 trigger"""
|
||||
|
||||
async def register_wait(
|
||||
self, handler: Callable[[str], Awaitable[Any]], timeout: int = 30
|
||||
) -> Any:
|
||||
"""等待外部输入并处理"""
|
||||
self.handler = handler
|
||||
USER_SESSIONS[self.session_id] = self
|
||||
|
||||
# 开始一个会话保持事件
|
||||
self.session_controller.keep(timeout, reset_timeout=True)
|
||||
|
||||
try:
|
||||
return await self.session_controller.future
|
||||
except Exception as e:
|
||||
self._cleanup(e)
|
||||
raise e
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self, error: Exception = None):
|
||||
"""清理会话"""
|
||||
USER_SESSIONS.pop(self.session_id, None)
|
||||
self.session_controller.stop(error)
|
||||
|
||||
@classmethod
|
||||
async def trigger(cls, session_id: str, event: AstrMessageEvent):
|
||||
"""外部输入触发会话处理"""
|
||||
session = USER_SESSIONS.get(session_id, None)
|
||||
if not session or session.session_controller.future.done():
|
||||
return
|
||||
|
||||
async with session._lock:
|
||||
if not session.session_controller.future.done():
|
||||
if session.record_history_chains:
|
||||
session.session_controller.history_chains.append(
|
||||
[copy.deepcopy(comp) for comp in event.get_messages()]
|
||||
)
|
||||
try:
|
||||
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
||||
await session.handler(session.session_controller, event)
|
||||
except Exception as e:
|
||||
session.session_controller.stop(e)
|
||||
|
||||
|
||||
def session_waiter(session_id_param: str, timeout: int = 30, record_history_chains: bool = False):
|
||||
"""
|
||||
装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。
|
||||
|
||||
:param session_id_param: 用于从参数中获取 session_id 的参数名称
|
||||
:param timeout: 超时时间(秒)
|
||||
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
session_id = kwargs.get(session_id_param)
|
||||
if not session_id:
|
||||
raise ValueError(f"缺少 session_id 参数 '{session_id_param}'")
|
||||
|
||||
waiter = SessionWaiter(session_id, record_history_chains)
|
||||
return await waiter.register_wait(func, timeout)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
25
packages/session_controller/main.py
Normal file
25
packages/session_controller/main.py
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.star import Context, Star, register
|
||||
from astrbot.core.utils.session_waiter import SessionWaiter, USER_SESSIONS
|
||||
from sys import maxsize
|
||||
|
||||
@register(
|
||||
"session_controller",
|
||||
"Cvandia & Soulter",
|
||||
"为插件支持会话控制",
|
||||
"v1.0.1",
|
||||
"https://astrbot.app",
|
||||
)
|
||||
class Waiter(Star):
|
||||
"""会话控制"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context)
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
|
||||
async def handle_session_control_agent(self, event: AstrMessageEvent):
|
||||
session_id = event.unified_msg_origin
|
||||
if session_id in USER_SESSIONS:
|
||||
await SessionWaiter.trigger(session_id, event)
|
||||
event.stop_event()
|
||||
Reference in New Issue
Block a user