* refactor: llm message schema * feat: implement MCPTool and local LLM tools with enhanced context handling * refactor: reorganize imports and enhance docstrings for clarity * refactor: enhance ContentPart validation and add message pair handling in ConversationManager * chore: ruff format * refactor: remove debug print statement from payloads in ProviderOpenAIOfficial * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/pipeline/process_stage/method/llm_request.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: enhance documentation and import mcp in tool.py; update call method return type * fix: 修复以数据类的方式注册 tool 时的插件重载机制问题 * refactor: change role attributes to use Literal types for message segments * fix: add support for 'decorator_handler' method in call_local_llm_tool * fix: handle None prompt in text_chat method and ensure context is properly formatted --------- Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
173 lines
5.9 KiB
Python
173 lines
5.9 KiB
Python
import inspect
|
||
import traceback
|
||
import typing as T
|
||
|
||
from astrbot import logger
|
||
from astrbot.core.agent.run_context import ContextWrapper
|
||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
|
||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||
from astrbot.core.star.star import star_map
|
||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||
|
||
|
||
async def call_handler(
|
||
event: AstrMessageEvent,
|
||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||
*args,
|
||
**kwargs,
|
||
) -> T.AsyncGenerator[T.Any, None]:
|
||
"""执行事件处理函数并处理其返回结果
|
||
|
||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||
2. 协程: 执行一次并处理返回值
|
||
|
||
Args:
|
||
event (AstrMessageEvent): 事件对象
|
||
handler (Awaitable): 事件处理函数
|
||
|
||
Returns:
|
||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||
|
||
"""
|
||
ready_to_call = None # 一个协程或者异步生成器
|
||
|
||
trace_ = None
|
||
|
||
try:
|
||
ready_to_call = handler(event, *args, **kwargs)
|
||
except TypeError:
|
||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||
|
||
if not ready_to_call:
|
||
return
|
||
|
||
if inspect.isasyncgen(ready_to_call):
|
||
_has_yielded = False
|
||
try:
|
||
async for ret in ready_to_call:
|
||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||
_has_yielded = True
|
||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||
event.set_result(ret)
|
||
yield
|
||
else:
|
||
# 如果返回值是 None, 则不设置结果并继续
|
||
# 继续执行后续阶段
|
||
yield ret
|
||
if not _has_yielded:
|
||
# 如果这个异步生成器没有执行到 yield 分支
|
||
yield
|
||
except Exception as e:
|
||
logger.error(f"Previous Error: {trace_}")
|
||
raise e
|
||
elif inspect.iscoroutine(ready_to_call):
|
||
# 如果只是一个协程, 直接执行
|
||
ret = await ready_to_call
|
||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||
event.set_result(ret)
|
||
yield
|
||
else:
|
||
yield ret
|
||
|
||
|
||
async def call_event_hook(
|
||
event: AstrMessageEvent,
|
||
hook_type: EventType,
|
||
*args,
|
||
**kwargs,
|
||
) -> bool:
|
||
"""调用事件钩子函数
|
||
|
||
Returns:
|
||
bool: 如果事件被终止,返回 True
|
||
#
|
||
|
||
"""
|
||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||
hook_type,
|
||
plugins_name=event.plugins_name,
|
||
)
|
||
for handler in handlers:
|
||
try:
|
||
logger.debug(
|
||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||
)
|
||
await handler.handler(event, *args, **kwargs)
|
||
except BaseException:
|
||
logger.error(traceback.format_exc())
|
||
|
||
if event.is_stopped():
|
||
logger.info(
|
||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。",
|
||
)
|
||
return True
|
||
|
||
return event.is_stopped()
|
||
|
||
|
||
async def call_local_llm_tool(
|
||
context: ContextWrapper[AstrAgentContext],
|
||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||
method_name: str,
|
||
*args,
|
||
**kwargs,
|
||
) -> T.AsyncGenerator[T.Any, None]:
|
||
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
|
||
ready_to_call = None # 一个协程或者异步生成器
|
||
|
||
trace_ = None
|
||
|
||
event = context.context.event
|
||
|
||
try:
|
||
if method_name == "run" or method_name == "decorator_handler":
|
||
ready_to_call = handler(event, *args, **kwargs)
|
||
elif method_name == "call":
|
||
ready_to_call = handler(context, *args, **kwargs)
|
||
else:
|
||
raise ValueError(f"未知的方法名: {method_name}")
|
||
except ValueError as e:
|
||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||
except TypeError:
|
||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||
except Exception as e:
|
||
trace_ = traceback.format_exc()
|
||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||
|
||
if not ready_to_call:
|
||
return
|
||
|
||
if inspect.isasyncgen(ready_to_call):
|
||
_has_yielded = False
|
||
try:
|
||
async for ret in ready_to_call:
|
||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||
_has_yielded = True
|
||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||
event.set_result(ret)
|
||
yield
|
||
else:
|
||
# 如果返回值是 None, 则不设置结果并继续
|
||
# 继续执行后续阶段
|
||
yield ret
|
||
if not _has_yielded:
|
||
# 如果这个异步生成器没有执行到 yield 分支
|
||
yield
|
||
except Exception as e:
|
||
logger.error(f"Previous Error: {trace_}")
|
||
raise e
|
||
elif inspect.iscoroutine(ready_to_call):
|
||
# 如果只是一个协程, 直接执行
|
||
ret = await ready_to_call
|
||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||
event.set_result(ret)
|
||
yield
|
||
else:
|
||
yield ret
|