Files
AstrBot/astrbot/core/provider/func_tool_manager.py
Soulter 6d6fefc435 fix: anyio.ClosedResourceError when calling mcp tools (#3700)
* fix: anyio.ClosedResourceError when calling mcp tools

added reconnect mechanism

fixes: 3676

* fix(mcp_client): implement thread-safe reconnection using asyncio.Lock
2025-11-20 16:01:22 +08:00

587 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
from typing import Any
import aiohttp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
PY_TO_JSON_TYPE = {
"int": "number",
"float": "number",
"bool": "boolean",
"str": "string",
"dict": "object",
"list": "array",
"tuple": "array",
"set": "array",
}
# alias
FuncTool = FunctionTool
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
class FunctionToolManager:
def __init__(self) -> None:
self.func_list: list[FuncTool] = []
self.mcp_client_dict: dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: dict[str, asyncio.Event] = {}
def empty(self) -> bool:
return len(self.func_list) == 0
def spec_to_func(
self,
name: str,
func_args: list[dict],
desc: str,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
p = copy.deepcopy(param)
p.pop("name", None)
params["properties"][param["name"]] = p
return FuncTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""添加函数调用工具
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
"""
# check if the tool has been added before
self.remove_func(name)
self.func_list.append(
self.spec_to_func(
name=name,
func_args=func_args,
desc=desc,
handler=handler,
),
)
logger.info(f"添加函数调用工具: {name}")
def remove_func(self, name: str) -> None:
"""删除一个函数调用工具。"""
for i, f in enumerate(self.func_list):
if f.name == name:
self.func_list.pop(i)
break
def get_func(self, name) -> FuncTool | None:
for f in self.func_list:
if f.name == name:
return f
def get_full_tool_set(self) -> ToolSet:
"""获取完整工具集"""
tool_set = ToolSet(self.func_list.copy())
return tool_set
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
"mcpServers": {
"weather": {
"command": "uv",
"args": [
"--directory",
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
"run",
"weather.py"
]
}
}
...
}
```
"""
data_dir = get_astrbot_data_path()
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return
mcp_server_json_obj: dict[str, dict] = json.load(
open(mcp_json_file, encoding="utf-8"),
)["mcpServers"]
for name in mcp_server_json_obj:
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
)
self.mcp_client_event[name] = event
async def _init_mcp_client_task_wrapper(
self,
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
ready_future.set_exception(e)
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
mcp_client = MCPClient()
mcp_client.name = name
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]
# 移除该MCP服务之前的工具如有
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
func_tool = MCPTool(
mcp_tool=tool,
mcp_client=mcp_client,
mcp_server_name=name,
)
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
client = self.mcp_client_dict[name]
try:
# 关闭MCP连接
await client.cleanup()
except Exception as e:
logger.error(f"清空 MCP 客户端资源 {name}: {e}")
finally:
# Remove client from dict after cleanup attempt (successful or not)
self.mcp_client_dict.pop(name, None)
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
if "url" in config:
success, error_msg = await _quick_test_mcp_connection(config)
if not success:
raise Exception(error_msg)
mcp_client = MCPClient()
try:
logger.debug(f"testing MCP server connection with config: {config}")
await mcp_client.connect_to_server(config, "test")
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
finally:
logger.debug("Cleaning up MCP client after testing connection.")
await mcp_client.cleanup()
return tool_names
async def enable_mcp_server(
self,
name: str,
config: dict,
event: asyncio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
) -> None:
"""Enable_mcp_server a new MCP server to the manager and initialize it.
Args:
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.
Raises:
TimeoutError: If the initialization does not complete within the specified timeout.
Exception: If there is an error during initialization.
"""
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self,
name: str | None = None,
timeout: float = 10,
) -> None:
"""Disable an MCP server by its name.
Args:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout (int): Timeout.
"""
if name:
if name not in self.mcp_client_event:
return
client = self.mcp_client_dict.get(name)
self.mcp_client_event[name].set()
if not client:
return
client_running_event = client.running_event
try:
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
finally:
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
else:
running_events = [
client.running_event.wait() for client in self.mcp_client_dict.values()
]
for key, event in self.mcp_client_event.items():
event.set()
# waiting for all clients to finish
try:
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [
f for f in self.func_list if not isinstance(f, MCPTool)
]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.openai_schema(
omit_empty_parameter_field=omit_empty_parameter_field,
)
def get_func_desc_anthropic_style(self) -> list:
"""获得 Anthropic API 风格的**已经激活**的工具描述"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.anthropic_schema()
def get_func_desc_google_genai_style(self) -> dict:
"""获得 Google GenAI API 风格的**已经激活**的工具描述"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.google_schema()
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False
"""
func_tool = self.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools",
[],
scope="global",
scope_id="global",
)
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
return True
return False
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
func_tool = self.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。",
)
func_tool.active = True
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools",
[],
scope="global",
scope_id="global",
)
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
return True
return False
@property
def mcp_config_path(self):
data_dir = get_astrbot_data_path()
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config(self):
if not os.path.exists(self.mcp_config_path):
# 配置文件不存在,创建默认配置
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
try:
with open(self.mcp_config_path, encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"加载 MCP 配置失败: {e}")
return DEFAULT_MCP_CONFIG
def save_mcp_config(self, config: dict):
try:
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception as e:
logger.error(f"保存 MCP 配置失败: {e}")
return False
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
"""从 ModelScope 平台同步 MCP 服务器配置"""
base_url = "https://www.modelscope.cn/openapi/v1"
url = f"{base_url}/mcp/servers/operational"
headers = {
"Authorization": f"Bearer {access_token.strip()}",
"Content-Type": "application/json",
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
mcp_server_list = data.get("data", {}).get(
"mcp_server_list",
[],
)
local_mcp_config = self.load_mcp_config()
synced_count = 0
for server in mcp_server_list:
server_name = server["name"]
operational_urls = server.get("operational_urls", [])
if not operational_urls:
continue
url_info = operational_urls[0]
server_url = url_info.get("url")
if not server_url:
continue
# 添加到配置中(同名会覆盖)
local_mcp_config["mcpServers"][server_name] = {
"url": server_url,
"transport": "sse",
"active": True,
"provider": "modelscope",
}
synced_count += 1
if synced_count > 0:
self.save_mcp_config(local_mcp_config)
tasks = []
for server in mcp_server_list:
name = server["name"]
tasks.append(
self.enable_mcp_server(
name=name,
config=local_mcp_config["mcpServers"][name],
),
)
await asyncio.gather(*tasks)
logger.info(
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器",
)
else:
logger.warning("没有找到可用的 ModelScope MCP 服务器")
else:
raise Exception(
f"ModelScope API 请求失败: HTTP {response.status}",
)
except aiohttp.ClientError as e:
raise Exception(f"网络连接错误: {e!s}")
except Exception as e:
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}")
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
# alias
FuncCall = FunctionToolManager