feat: 事件钩子支持 yield 方式发送消息

This commit is contained in:
Soulter
2025-02-19 15:29:10 +08:00
parent 4678222e9b
commit 782c0367d0
6 changed files with 30 additions and 30 deletions

View File

@@ -64,12 +64,14 @@ class LLMRequestSubStage(Stage):
if not req.prompt and not req.image_urls:
return
# 执行请求 LLM 前事件。
# 执行请求 LLM 前事件钩子
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
await handler.handler(event, req)
wrapper = self._call_handler(self.ctx, event, handler.handler, req)
async for ret in wrapper:
yield ret
except BaseException:
logger.error(traceback.format_exc())
@@ -86,7 +88,9 @@ class LLMRequestSubStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
await handler.handler(event, llm_response)
wrapper = self._call_handler(self.ctx, event, handler.handler, llm_response)
async for ret in wrapper:
yield ret
except BaseException:
logger.error(traceback.format_exc())

View File

@@ -1,6 +1,7 @@
import random
import asyncio
import math
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
@@ -88,7 +89,11 @@ class RespondStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
try:
wrapper = self._call_handler(self.ctx, event, handler.handler)
async for ret in wrapper:
yield ret
except BaseException:
logger.error(traceback.format_exc())
event.clear_result()

View File

@@ -59,9 +59,16 @@ class ResultDecorateStage(Stage):
async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
await handler.handler(event)
try:
wrapper = self._call_handler(self.ctx, event, handler.handler)
async for ret in wrapper:
yield ret
except BaseException:
logger.error(traceback.format_exc())
# 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result()

View File

@@ -36,16 +36,17 @@ class Stage(abc.ABC):
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
**params
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
'''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
try:
ready_to_call = handler(event, **params)
ready_to_call = handler(event, *args, **kwargs)
except TypeError as e:
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
async for ret in ready_to_call:

View File

@@ -77,15 +77,11 @@ class WakingCheckStage(Stage):
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
# 在输入指令组正确但是子指令错误的情况下提醒用户
command_group_passed = False
command_group_tree = None
if len(handler.event_filters) == 0:
continue
@@ -94,12 +90,6 @@ class WakingCheckStage(Stage):
if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
elif isinstance(filter, CommandGroupFilter):
if filter.filter(event, self.ctx.astrbot_config):
command_group_passed = True
command_group_tree = filter.print_cmd_tree(filter.sub_command_filters)
passed = False
break
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -128,14 +118,6 @@ class WakingCheckStage(Stage):
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
if not passed and command_group_passed:
await event.send(
MessageEventResult().message(
f"插件 {star_map[handler.handler_module_path].name} 没有该指令。指令树:\n{command_group_tree}"
)
)
event.stop_event()
return
event.clear_extra()

View File

@@ -97,5 +97,6 @@ class CommandGroupFilter(HandlerFilter):
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
complete_command_names = [name + " " for name in complete_command_names]
return event.message_str.startswith(tuple(complete_command_names))
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False