247 lines
8.6 KiB
Python
247 lines
8.6 KiB
Python
import asyncio
|
||
import inspect
|
||
import traceback
|
||
import typing as T
|
||
|
||
import mcp
|
||
|
||
from astrbot import logger
|
||
from astrbot.core.agent.handoff import HandoffTool
|
||
from astrbot.core.agent.mcp_client import MCPTool
|
||
from astrbot.core.agent.run_context import ContextWrapper
|
||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||
from astrbot.core.message.message_event_result import (
|
||
CommandResult,
|
||
MessageChain,
|
||
MessageEventResult,
|
||
)
|
||
from astrbot.core.provider.register import llm_tools
|
||
|
||
|
||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
    """Executes agent tool calls.

    ``execute`` dispatches on the tool's type: ``HandoffTool`` instances run a
    sub-agent, ``MCPTool`` instances are run through their own ``call``
    method, and every other tool is treated as a local function tool.
    """

    @classmethod
    async def execute(cls, tool, run_context, **tool_args):
        """Execute a tool call and yield its results.

        Args:
            tool: The tool to execute.
            run_context (ContextWrapper[AstrAgentContext]): Run context of the
                current agent. ``run_context.context.event`` must be set for
                local function tools.
            **tool_args: Arguments of the tool call.

        Yields:
            ``mcp.types.CallToolResult`` objects. Local tools may also yield
            ``None`` for steps that produced no value for the LLM (e.g. when
            the message was sent directly to the user). MCP tools yield
            whatever truthy value ``tool.call`` returns.

        Raises:
            ValueError: A local tool has no event or no usable entry point.
            Exception: A local tool step exceeded
                ``run_context.tool_call_timeout``.
        """
        if isinstance(tool, HandoffTool):
            results = cls._execute_handoff(tool, run_context, **tool_args)
        elif isinstance(tool, MCPTool):
            results = cls._execute_mcp(tool, run_context, **tool_args)
        else:
            results = cls._execute_local(tool, run_context, **tool_args)
        async for r in results:
            yield r

    @staticmethod
    def _build_handoff_toolset(tools) -> T.Optional[ToolSet]:
        """Build the ToolSet for a handoff sub-agent, or None if it has no tools.

        String entries are resolved through ``llm_tools.get_func``;
        ``FunctionTool`` instances are added as-is; other entries are ignored.
        """
        if not tools:
            return None
        toolset = ToolSet()
        for t in tools:
            if isinstance(t, str):
                resolved = llm_tools.get_func(t)
                if resolved:
                    toolset.add_tool(resolved)
                else:
                    # Previously skipped silently, which made misconfigured
                    # handoff agents hard to diagnose.
                    logger.warning(
                        f"Handoff agent tool {t!r} could not be resolved; skipping it."
                    )
            elif isinstance(t, FunctionTool):
                toolset.add_tool(t)
        return toolset

    @classmethod
    async def _execute_handoff(
        cls,
        tool: HandoffTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        """Delegate the call to the sub-agent wrapped by a ``HandoffTool``.

        The sub-agent runs through ``tool_loop_agent`` with the chat provider
        currently selected for the originating session, the agent's
        instructions as system prompt, and at most 30 steps.

        Yields:
            A single ``mcp.types.CallToolResult`` holding the sub-agent's
            completion text.
        """
        input_ = tool_args.get("input")

        toolset = cls._build_handoff_toolset(tool.agent.tools)

        ctx = run_context.context.context
        event = run_context.context.event
        umo = event.unified_msg_origin
        prov_id = await ctx.get_current_chat_provider_id(umo)
        llm_resp = await ctx.tool_loop_agent(
            event=event,
            chat_provider_id=prov_id,
            prompt=input_,
            system_prompt=tool.agent.instructions,
            tools=toolset,
            max_steps=30,
            run_hooks=tool.agent.run_hooks,
        )
        yield mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
        )

    @staticmethod
    def _overrides_call(tool: FunctionTool) -> bool:
        """Return whether ``tool.call`` resolves to something other than ``FunctionTool.call``.

        Only the first class in the MRO that defines ``call`` matters: that is
        the definition attribute lookup actually uses. Classes further down
        the MRO (e.g. bases of ``FunctionTool``) must not count as overrides.
        """
        for klass in type(tool).__mro__:
            if "call" in klass.__dict__:
                return klass.__dict__["call"] is not FunctionTool.call
        return False

    @staticmethod
    async def _send_direct_result(event) -> None:
        """Send the message chain stored on ``event`` (if any) straight to the user.

        Send failures are logged rather than raised, so a failed direct reply
        does not abort the tool call.
        """
        # TODO: should an empty event result be reported to the LLM? If it is
        # empty, the tool neither messaged the user nor returned anything, and
        # a special TextContent such as "the tool returned no content" could
        # be returned instead.
        # NOTE(review): the event result is not cleared after sending, so a
        # result left on the event by an earlier step may be sent again by a
        # later step that yields None -- confirm whether a caller clears it.
        res = event.get_result()
        if not res or not res.chain:
            return
        try:
            await event.send(
                MessageChain(
                    chain=res.chain,
                    type="tool_direct_result",
                )
            )
        except Exception as e:
            logger.error(
                f"Tool 直接发送消息失败: {e}",
                exc_info=True,
            )

    @classmethod
    async def _execute_local(
        cls,
        tool: FunctionTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        """Run a local function tool and yield one result per step.

        The entry point is chosen in priority order: ``tool.handler``
        (decorator-registered tools), an overridden ``call`` method, then a
        ``run`` attribute. Each step produced by the entry point is awaited
        with ``run_context.tool_call_timeout`` as a per-step timeout.

        Yields:
            ``mcp.types.CallToolResult`` for every non-``None`` value (values
            that are not already a ``CallToolResult`` are wrapped as text via
            ``str``), or ``None`` for steps that produced no value; for those
            steps any message chain stored on the event is first sent
            directly to the user.

        Raises:
            ValueError: No event is available, or the tool has no entry point.
            Exception: A step did not finish within the timeout.
        """
        event = run_context.context.event
        if not event:
            raise ValueError("Event must be provided for local function tools.")

        awaitable = None
        method_name = ""
        if tool.handler:
            awaitable = tool.handler
            method_name = "decorator_handler"
        elif cls._overrides_call(tool):
            awaitable = tool.call
            method_name = "call"
        elif hasattr(tool, "run"):
            awaitable = getattr(tool, "run")
            method_name = "run"
        # Also covers a ``run`` attribute that exists but is None.
        if awaitable is None:
            raise ValueError("Tool must have a valid handler or override 'run' method.")

        wrapper = call_local_llm_tool(
            context=run_context,
            handler=awaitable,
            method_name=method_name,
            **tool_args,
        )
        try:
            while True:
                try:
                    resp = await asyncio.wait_for(
                        anext(wrapper),
                        timeout=run_context.tool_call_timeout,
                    )
                except StopAsyncIteration:
                    break
                except asyncio.TimeoutError:
                    raise Exception(
                        f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
                    )

                if isinstance(resp, mcp.types.CallToolResult):
                    yield resp
                elif resp is not None:
                    text_content = mcp.types.TextContent(
                        type="text",
                        text=str(resp),
                    )
                    yield mcp.types.CallToolResult(content=[text_content])
                else:
                    # The tool asked for its message to be sent to the user
                    # directly rather than returned to the LLM.
                    await cls._send_direct_result(event)
                    yield None
        finally:
            # Finalize the handler's generator deterministically (after a
            # timeout, an error, or when the consumer stops iterating early)
            # instead of leaving it to garbage collection.
            await wrapper.aclose()

    @classmethod
    async def _execute_mcp(
        cls,
        tool: MCPTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        """Run an MCP tool via its ``call`` method and yield the result if it is truthy."""
        res = await tool.call(run_context, **tool_args)
        if not res:
            return
        yield res
|
||
|
||
|
||
async def call_local_llm_tool(
    context: ContextWrapper[AstrAgentContext],
    handler: T.Callable[..., T.Awaitable[T.Any]],
    method_name: str,
    *args,
    **kwargs,
) -> T.AsyncGenerator[T.Any, None]:
    """Invoke a local LLM tool handler and yield what it produces.

    ``"run"`` and ``"decorator_handler"`` handlers are called as
    ``handler(event, *args, **kwargs)``; ``"call"`` handlers as
    ``handler(context, *args, **kwargs)``.

    Every ``MessageEventResult`` / ``CommandResult`` the handler returns or
    yields is stored on the event via ``event.set_result`` and surfaced as
    ``None``; any other value is yielded unchanged. An async generator that
    yields nothing still produces a single ``None``.

    If invoking the handler fails (unknown method name, argument mismatch, or
    any exception raised by the call itself), the error is logged and nothing
    is yielded. Exceptions raised while awaiting or iterating the handler's
    result propagate to the caller.
    """
    event = context.context.event

    def _unwrap(ret):
        # Message results are parked on the event and replaced by None;
        # everything else passes through unchanged.
        if isinstance(ret, (MessageEventResult, CommandResult)):
            event.set_result(ret)
            return None
        return ret

    try:
        if method_name in ("run", "decorator_handler"):
            ready_to_call = handler(event, *args, **kwargs)
        elif method_name == "call":
            ready_to_call = handler(context, *args, **kwargs)
        else:
            raise ValueError(f"未知的方法名: {method_name}")
    except TypeError:
        logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
        return
    except Exception as e:
        logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
        return

    if inspect.isasyncgen(ready_to_call):
        has_yielded = False
        async for ret in ready_to_call:
            has_yielded = True
            yield _unwrap(ret)
        if not has_yielded:
            # The generator never yielded; still report one (empty) step.
            yield None
    elif inspect.isawaitable(ready_to_call):
        # Coroutines as well as any other awaitable (e.g. futures).
        yield _unwrap(await ready_to_call)
    elif ready_to_call is not None:
        # A synchronous handler already ran; its return value cannot be used
        # here. Previously this was dropped silently.
        logger.warning(
            f"Local LLM tool handler ({method_name}) returned a non-awaitable "
            f"{type(ready_to_call).__name__}; the result is ignored."
        )
|