✨ feat: 事件钩子支持 yield 方式发送消息
This commit is contained in:
@@ -64,12 +64,14 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件。
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, req)
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, req)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -86,7 +88,9 @@ class LLMRequestSubStage(Stage):
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, llm_response)
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, llm_response)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
import asyncio
|
||||
import math
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
@@ -88,7 +89,11 @@ class RespondStage(Stage):
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
try:
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
event.clear_result()
|
||||
@@ -59,9 +59,16 @@ class ResultDecorateStage(Stage):
|
||||
async for _ in self.content_safe_check_stage.process(event, check_text=text):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
|
||||
for handler in handlers:
|
||||
await handler.handler(event)
|
||||
try:
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||
result = event.get_result()
|
||||
|
||||
@@ -36,16 +36,17 @@ class Stage(abc.ABC):
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
**params
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
'''调用 Handler。'''
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
try:
|
||||
ready_to_call = handler(event, **params)
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as e:
|
||||
# 向下兼容
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for ret in ready_to_call:
|
||||
|
||||
@@ -77,15 +77,11 @@ class WakingCheckStage(Stage):
|
||||
# 检查插件的 handler filter
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
|
||||
# 在输入指令组正确但是子指令错误的情况下提醒用户
|
||||
command_group_passed = False
|
||||
command_group_tree = None
|
||||
|
||||
if len(handler.event_filters) == 0:
|
||||
continue
|
||||
|
||||
@@ -94,12 +90,6 @@ class WakingCheckStage(Stage):
|
||||
if isinstance(filter, PermissionTypeFilter):
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
permission_not_pass = True
|
||||
elif isinstance(filter, CommandGroupFilter):
|
||||
if filter.filter(event, self.ctx.astrbot_config):
|
||||
command_group_passed = True
|
||||
command_group_tree = filter.print_cmd_tree(filter.sub_command_filters)
|
||||
passed = False
|
||||
break
|
||||
else:
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
passed = False
|
||||
@@ -128,14 +118,6 @@ class WakingCheckStage(Stage):
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
if not passed and command_group_passed:
|
||||
await event.send(
|
||||
MessageEventResult().message(
|
||||
f"插件 {star_map[handler.handler_module_path].name} 没有该指令。指令树:\n{command_group_tree}"
|
||||
)
|
||||
)
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
event.clear_extra()
|
||||
|
||||
|
||||
@@ -97,5 +97,6 @@ class CommandGroupFilter(HandlerFilter):
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
|
||||
|
||||
complete_command_names = [name + " " for name in complete_command_names]
|
||||
return event.message_str.startswith(tuple(complete_command_names))
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
# return event.message_str.startswith(tuple(complete_command_names))
|
||||
return False
|
||||
Reference in New Issue
Block a user