fix: anyio.ClosedResourceError when calling mcp tools (#3700)
* fix: anyio.ClosedResourceError when calling MCP tools; added a reconnect mechanism (fixes #3676)
* fix(mcp_client): implement concurrency-safe reconnection using asyncio.Lock
This commit is contained in:
@@ -4,6 +4,14 @@ from contextlib import AsyncExitStack
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Generic
|
from typing import Generic
|
||||||
|
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.agent.run_context import ContextWrapper
|
from astrbot.core.agent.run_context import ContextWrapper
|
||||||
from astrbot.core.utils.log_pipe import LogPipe
|
from astrbot.core.utils.log_pipe import LogPipe
|
||||||
@@ -12,21 +20,24 @@ from .run_context import TContext
|
|||||||
from .tool import FunctionTool
|
from .tool import FunctionTool
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
import anyio
|
||||||
import mcp
|
import mcp
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
logger.warning(
|
||||||
|
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
|
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _prepare_config(config: dict) -> dict:
|
def _prepare_config(config: dict) -> dict:
|
||||||
"""准备配置,处理嵌套格式"""
|
"""Prepare configuration, handle nested format"""
|
||||||
if config.get("mcpServers"):
|
if config.get("mcpServers"):
|
||||||
first_key = next(iter(config["mcpServers"]))
|
first_key = next(iter(config["mcpServers"]))
|
||||||
config = config["mcpServers"][first_key]
|
config = config["mcpServers"][first_key]
|
||||||
@@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||||
"""快速测试 MCP 服务器可达性"""
|
"""Quick test MCP server connectivity"""
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
cfg = _prepare_config(config.copy())
|
cfg = _prepare_config(config.copy())
|
||||||
@@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|||||||
elif "type" in cfg:
|
elif "type" in cfg:
|
||||||
transport_type = cfg["type"]
|
transport_type = cfg["type"]
|
||||||
else:
|
else:
|
||||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
raise Exception("MCP connection config missing transport or type field")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
if transport_type == "streamable_http":
|
if transport_type == "streamable_http":
|
||||||
@@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|||||||
return False, f"HTTP {response.status}: {response.reason}"
|
return False, f"HTTP {response.status}: {response.reason}"
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
return False, f"连接超时: {timeout}秒"
|
return False, f"Connection timeout: {timeout} seconds"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False, f"{e!s}"
|
return False, f"{e!s}"
|
||||||
|
|
||||||
@@ -101,6 +112,7 @@ class MCPClient:
|
|||||||
# Initialize session and client objects
|
# Initialize session and client objects
|
||||||
self.session: mcp.ClientSession | None = None
|
self.session: mcp.ClientSession | None = None
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
|
||||||
|
|
||||||
self.name: str | None = None
|
self.name: str | None = None
|
||||||
self.active: bool = True
|
self.active: bool = True
|
||||||
@@ -108,22 +120,32 @@ class MCPClient:
|
|||||||
self.server_errlogs: list[str] = []
|
self.server_errlogs: list[str] = []
|
||||||
self.running_event = asyncio.Event()
|
self.running_event = asyncio.Event()
|
||||||
|
|
||||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
# Store connection config for reconnection
|
||||||
"""连接到 MCP 服务器
|
self._mcp_server_config: dict | None = None
|
||||||
|
self._server_name: str | None = None
|
||||||
|
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
|
||||||
|
self._reconnecting: bool = False # For logging and debugging
|
||||||
|
|
||||||
如果 `url` 参数存在:
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
"""Connect to MCP server
|
||||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
|
||||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
If `url` parameter exists:
|
||||||
|
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
|
||||||
|
2. When transport is specified as `sse`, use SSE connection.
|
||||||
|
3. If not specified, default to SSE connection to MCP service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# Store config for reconnection
|
||||||
|
self._mcp_server_config = mcp_server_config
|
||||||
|
self._server_name = name
|
||||||
|
|
||||||
cfg = _prepare_config(mcp_server_config.copy())
|
cfg = _prepare_config(mcp_server_config.copy())
|
||||||
|
|
||||||
def logging_callback(msg: str):
|
def logging_callback(msg: str):
|
||||||
# 处理 MCP 服务的错误日志
|
# Handle MCP service error logs
|
||||||
print(f"MCP Server {name} Error: {msg}")
|
print(f"MCP Server {name} Error: {msg}")
|
||||||
self.server_errlogs.append(msg)
|
self.server_errlogs.append(msg)
|
||||||
|
|
||||||
@@ -137,7 +159,7 @@ class MCPClient:
|
|||||||
elif "type" in cfg:
|
elif "type" in cfg:
|
||||||
transport_type = cfg["type"]
|
transport_type = cfg["type"]
|
||||||
else:
|
else:
|
||||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
raise Exception("MCP connection config missing transport or type field")
|
||||||
|
|
||||||
if transport_type != "streamable_http":
|
if transport_type != "streamable_http":
|
||||||
# SSE transport method
|
# SSE transport method
|
||||||
@@ -193,7 +215,7 @@ class MCPClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def callback(msg: str):
|
def callback(msg: str):
|
||||||
# 处理 MCP 服务的错误日志
|
# Handle MCP service error logs
|
||||||
self.server_errlogs.append(msg)
|
self.server_errlogs.append(msg)
|
||||||
|
|
||||||
stdio_transport = await self.exit_stack.enter_async_context(
|
stdio_transport = await self.exit_stack.enter_async_context(
|
||||||
@@ -222,10 +244,120 @@ class MCPClient:
|
|||||||
self.tools = response.tools
|
self.tools = response.tools
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def _reconnect(self, stale_session: object | None = None) -> None:
    """Reconnect to the MCP server using the stored configuration.

    Reconnection is serialized with ``self._reconnect_lock`` (an
    ``asyncio.Lock``), so concurrent tasks rebuild the connection one at a
    time. Note that ``asyncio.Lock`` coordinates coroutines on one event
    loop; it does not make this method safe to call from multiple threads.

    Args:
        stale_session: The session object the caller just saw fail. If it is
            given and, once the lock is acquired, ``self.session`` already
            holds a different non-None session, another task has reconnected
            in the meantime and this call returns without reconnecting again.
            When omitted, a reconnect is always performed.

    Raises:
        Exception: If no connection configuration has been stored, or if
            connecting / listing tools fails (the error is logged and
            re-raised unchanged).
    """
    async with self._reconnect_lock:
        # While we waited for the lock another task may already have replaced
        # the dead session. Reconnecting again would only orphan a healthy
        # connection in _old_exit_stacks.
        # NOTE(review): relies on connect_to_server() assigning the new
        # session to self.session -- that part is not visible here; if it
        # does not, this check never matches and we simply reconnect.
        current_session = self.session
        if (
            stale_session is not None
            and current_session is not None
            and current_session is not stale_session
        ):
            logger.debug(
                f"MCP Client {self._server_name} was already reconnected "
                "by another task, skipping"
            )
            return

        if not self._mcp_server_config or not self._server_name:
            raise Exception("Cannot reconnect: missing connection configuration")

        # Kept purely as an observable flag for logging/debugging; mutual
        # exclusion is provided by the lock itself.
        self._reconnecting = True
        try:
            logger.info(
                f"Attempting to reconnect to MCP server {self._server_name}..."
            )

            # Save old exit_stack for later cleanup (don't close it now to
            # avoid cancel scope issues).
            if self.exit_stack:
                self._old_exit_stacks.append(self.exit_stack)

            # Mark old session as invalid
            self.session = None

            # Create new exit stack for new connection
            self.exit_stack = AsyncExitStack()

            # Reconnect using stored config
            await self.connect_to_server(self._mcp_server_config, self._server_name)
            await self.list_tools_and_save()

            logger.info(
                f"Successfully reconnected to MCP server {self._server_name}"
            )
        except Exception as e:
            logger.error(
                f"Failed to reconnect to MCP server {self._server_name}: {e}"
            )
            raise
        finally:
            self._reconnecting = False


async def call_tool_with_reconnect(
    self,
    tool_name: str,
    arguments: dict,
    read_timeout_seconds: timedelta,
) -> mcp.types.CallToolResult:
    """Call an MCP tool, reconnecting once if the connection was closed.

    The call is attempted at most twice in total (``stop_after_attempt(2)``,
    i.e. one retry). Only ``anyio.ClosedResourceError`` triggers a retry;
    before re-raising it, the client reconnects so the next attempt (or the
    next caller) sees a fresh session. The wait between attempts is
    exponential, bounded to 1-3 seconds.

    Args:
        tool_name: tool name
        arguments: tool arguments
        read_timeout_seconds: read timeout

    Returns:
        MCP tool call result

    Raises:
        ValueError: MCP session is not available.
        anyio.ClosedResourceError: The final attempt still failed on a
            closed connection.
        Exception: Any error raised while reconnecting propagates
            immediately without a further retry, because it is not a
            ``ClosedResourceError``.
    """

    @retry(
        retry=retry_if_exception_type(anyio.ClosedResourceError),
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=3),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
    async def _call_with_retry():
        # Snapshot the session so that, on failure, _reconnect() can tell
        # whether someone else has already replaced it.
        session = self.session
        if not session:
            raise ValueError("MCP session is not available for MCP function tools.")

        try:
            return await session.call_tool(
                name=tool_name,
                arguments=arguments,
                read_timeout_seconds=read_timeout_seconds,
            )
        except anyio.ClosedResourceError:
            logger.warning(
                f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
            )
            # Attempt to reconnect (no-op if another task already did)
            await self._reconnect(stale_session=session)
            # Reraise the exception to trigger tenacity retry
            raise

    return await _call_with_retry()
|
||||||
|
|
||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
"""Clean up resources"""
|
"""Clean up resources including old exit stacks from reconnections"""
|
||||||
await self.exit_stack.aclose()
|
# Set running_event first to unblock any waiting tasks
|
||||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
self.running_event.set()
|
||||||
|
|
||||||
|
# Close current exit stack
|
||||||
|
try:
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error closing current exit stack: {e}")
|
||||||
|
|
||||||
|
# Don't close old exit stacks as they may be in different task contexts
|
||||||
|
# They will be garbage collected naturally
|
||||||
|
# Just clear the list to release references
|
||||||
|
self._old_exit_stacks.clear()
|
||||||
|
|
||||||
|
|
||||||
class MCPTool(FunctionTool, Generic[TContext]):
|
class MCPTool(FunctionTool, Generic[TContext]):
|
||||||
@@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]):
|
|||||||
async def call(
|
async def call(
|
||||||
self, context: ContextWrapper[TContext], **kwargs
|
self, context: ContextWrapper[TContext], **kwargs
|
||||||
) -> mcp.types.CallToolResult:
|
) -> mcp.types.CallToolResult:
|
||||||
session = self.mcp_client.session
|
return await self.mcp_client.call_tool_with_reconnect(
|
||||||
if not session:
|
tool_name=self.mcp_tool.name,
|
||||||
raise ValueError("MCP session is not available for MCP function tools.")
|
|
||||||
res = await session.call_tool(
|
|
||||||
name=self.mcp_tool.name,
|
|
||||||
arguments=kwargs,
|
arguments=kwargs,
|
||||||
read_timeout_seconds=timedelta(
|
read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
|
||||||
seconds=context.tool_call_timeout,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return res
|
|
||||||
|
|||||||
@@ -280,19 +280,22 @@ class FunctionToolManager:
|
|||||||
async def _terminate_mcp_client(self, name: str) -> None:
|
async def _terminate_mcp_client(self, name: str) -> None:
|
||||||
"""关闭并清理MCP客户端"""
|
"""关闭并清理MCP客户端"""
|
||||||
if name in self.mcp_client_dict:
|
if name in self.mcp_client_dict:
|
||||||
|
client = self.mcp_client_dict[name]
|
||||||
try:
|
try:
|
||||||
# 关闭MCP连接
|
# 关闭MCP连接
|
||||||
await self.mcp_client_dict[name].cleanup()
|
await client.cleanup()
|
||||||
self.mcp_client_dict.pop(name)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||||
# 移除关联的FuncTool
|
finally:
|
||||||
self.func_list = [
|
# Remove client from dict after cleanup attempt (successful or not)
|
||||||
f
|
self.mcp_client_dict.pop(name, None)
|
||||||
for f in self.func_list
|
# 移除关联的FuncTool
|
||||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
self.func_list = [
|
||||||
]
|
f
|
||||||
logger.info(f"已关闭 MCP 服务 {name}")
|
for f in self.func_list
|
||||||
|
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
logger.info(f"已关闭 MCP 服务 {name}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ dependencies = [
|
|||||||
"jieba>=0.42.1",
|
"jieba>=0.42.1",
|
||||||
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
||||||
"xinference-client",
|
"xinference-client",
|
||||||
|
"tenacity>=9.1.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
@@ -107,4 +108,4 @@ exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
|||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|||||||
@@ -52,3 +52,4 @@ rank-bm25>=0.2.2
|
|||||||
jieba>=0.42.1
|
jieba>=0.42.1
|
||||||
markitdown-no-magika[docx,xls,xlsx]>=0.1.2
|
markitdown-no-magika[docx,xls,xlsx]>=0.1.2
|
||||||
xinference-client
|
xinference-client
|
||||||
|
tenacity>=9.1.2
|
||||||
Reference in New Issue
Block a user