From 6d6fefc4355ce71cbe935449a6bc763f0cb8a83a Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:01:22 +0800 Subject: [PATCH] fix: anyio.ClosedResourceError when calling mcp tools (#3700) * fix: anyio.ClosedResourceError when calling mcp tools added reconnect mechanism fixes: 3676 * fix(mcp_client): implement thread-safe reconnection using asyncio.Lock --- astrbot/core/agent/mcp_client.py | 180 +++++++++++++++++---- astrbot/core/provider/func_tool_manager.py | 21 +-- pyproject.toml | 3 +- requirements.txt | 1 + 4 files changed, 168 insertions(+), 37 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 05980b21..88cab486 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,6 +4,14 @@ from contextlib import AsyncExitStack from datetime import timedelta from typing import Generic +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe @@ -12,21 +20,24 @@ from .run_context import TContext from .tool import FunctionTool try: + import anyio import mcp from mcp.client.sse import sse_client except (ModuleNotFoundError, ImportError): - logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") + logger.warning( + "Warning: Missing 'mcp' dependency, MCP services will be unavailable." + ) try: from mcp.client.streamable_http import streamablehttp_client except (ModuleNotFoundError, ImportError): logger.warning( - "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。", + "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.", ) def _prepare_config(config: dict) -> dict: - """准备配置,处理嵌套格式""" + """Prepare configuration, handle nested format""" if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] @@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict: async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: - """快速测试 MCP 服务器可达性""" + """Quick test MCP server connectivity""" import aiohttp cfg = _prepare_config(config.copy()) @@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: elif "type" in cfg: transport_type = cfg["type"] else: - raise Exception("MCP 连接配置缺少 transport 或 type 字段") + raise Exception("MCP connection config missing transport or type field") async with aiohttp.ClientSession() as session: if transport_type == "streamable_http": @@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"HTTP {response.status}: {response.reason}" except asyncio.TimeoutError: - return False, f"连接超时: {timeout}秒" + return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" @@ -101,6 +112,7 @@ class MCPClient: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() + self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup self.name: str | None = None self.active: bool = True @@ -108,22 +120,32 @@ class MCPClient: self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() - async def connect_to_server(self, mcp_server_config: dict, name: str): - """连接到 MCP 服务器 + # Store connection config for reconnection + self._mcp_server_config: dict | None = None + self._server_name: str | None = None + self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection + self._reconnecting: bool = False # For logging and debugging - 如果 `url` 参数存在: - 1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。 - 1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。 - 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。 + async def connect_to_server(self, mcp_server_config: dict, name: str): + """Connect to MCP server + + If `url` parameter exists: + 1. When transport is specified as `streamable_http`, use Streamable HTTP connection. + 2. When transport is specified as `sse`, use SSE connection. + 3. If not specified, default to SSE connection to MCP service. Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ + # Store config for reconnection + self._mcp_server_config = mcp_server_config + self._server_name = name + cfg = _prepare_config(mcp_server_config.copy()) def logging_callback(msg: str): - # 处理 MCP 服务的错误日志 + # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -137,7 +159,7 @@ class MCPClient: elif "type" in cfg: transport_type = cfg["type"] else: - raise Exception("MCP 连接配置缺少 transport 或 type 字段") + raise Exception("MCP connection config missing transport or type field") if transport_type != "streamable_http": # SSE transport method @@ -193,7 +215,7 @@ class MCPClient: ) def callback(msg: str): - # 处理 MCP 服务的错误日志 + # Handle MCP service error logs self.server_errlogs.append(msg) stdio_transport = await self.exit_stack.enter_async_context( @@ -222,10 +244,120 @@ class MCPClient: self.tools = response.tools return response + async def _reconnect(self) -> None: + """Reconnect to the MCP server using the stored configuration. + + Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments. + + Raises: + Exception: raised when reconnection fails + """ + async with self._reconnect_lock: + # Check if already reconnecting (useful for logging) + if self._reconnecting: + logger.debug( + f"MCP Client {self._server_name} is already reconnecting, skipping" + ) + return + + if not self._mcp_server_config or not self._server_name: + raise Exception("Cannot reconnect: missing connection configuration") + + self._reconnecting = True + try: + logger.info( + f"Attempting to reconnect to MCP server {self._server_name}..." + ) + + # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues) + if self.exit_stack: + self._old_exit_stacks.append(self.exit_stack) + + # Mark old session as invalid + self.session = None + + # Create new exit stack for new connection + self.exit_stack = AsyncExitStack() + + # Reconnect using stored config + await self.connect_to_server(self._mcp_server_config, self._server_name) + await self.list_tools_and_save() + + logger.info( + f"Successfully reconnected to MCP server {self._server_name}" + ) + except Exception as e: + logger.error( + f"Failed to reconnect to MCP server {self._server_name}: {e}" + ) + raise + finally: + self._reconnecting = False + + async def call_tool_with_reconnect( + self, + tool_name: str, + arguments: dict, + read_timeout_seconds: timedelta, + ) -> mcp.types.CallToolResult: + """Call MCP tool with automatic reconnection on failure, max 2 retries. + + Args: + tool_name: tool name + arguments: tool arguments + read_timeout_seconds: read timeout + + Returns: + MCP tool call result + + Raises: + ValueError: MCP session is not available + anyio.ClosedResourceError: raised after reconnection failure + """ + + @retry( + retry=retry_if_exception_type(anyio.ClosedResourceError), + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=1, min=1, max=3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _call_with_retry(): + if not self.session: + raise ValueError("MCP session is not available for MCP function tools.") + + try: + return await self.session.call_tool( + name=tool_name, + arguments=arguments, + read_timeout_seconds=read_timeout_seconds, + ) + except anyio.ClosedResourceError: + logger.warning( + f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + ) + # Attempt to reconnect + await self._reconnect() + # Reraise the exception to trigger tenacity retry + raise + + return await _call_with_retry() + async def cleanup(self): - """Clean up resources""" - await self.exit_stack.aclose() - self.running_event.set() # Set the running event to indicate cleanup is done + """Clean up resources including old exit stacks from reconnections""" + # Set running_event first to unblock any waiting tasks + self.running_event.set() + + # Close current exit stack + try: + await self.exit_stack.aclose() + except Exception as e: + logger.debug(f"Error closing current exit stack: {e}") + + # Don't close old exit stacks as they may be in different task contexts + # They will be garbage collected naturally + # Just clear the list to release references + self._old_exit_stacks.clear() class MCPTool(FunctionTool, Generic[TContext]): @@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]): async def call( self, context: ContextWrapper[TContext], **kwargs ) -> mcp.types.CallToolResult: - session = self.mcp_client.session - if not session: - raise ValueError("MCP session is not available for MCP function tools.") - res = await session.call_tool( - name=self.mcp_tool.name, + return await self.mcp_client.call_tool_with_reconnect( + tool_name=self.mcp_tool.name, arguments=kwargs, - read_timeout_seconds=timedelta( - seconds=context.tool_call_timeout, - ), + read_timeout_seconds=timedelta(seconds=context.tool_call_timeout), ) - return res diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7cdbeec0..8e04423e 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -280,19 +280,22 @@ class FunctionToolManager: async def _terminate_mcp_client(self, name: str) -> None: """关闭并清理MCP客户端""" if name in self.mcp_client_dict: + client = self.mcp_client_dict[name] try: # 关闭MCP连接 - await self.mcp_client_dict[name].cleanup() - self.mcp_client_dict.pop(name) + await client.cleanup() except Exception as e: logger.error(f"清空 MCP 客户端资源 {name}: {e}。") - # 移除关联的FuncTool - self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] - logger.info(f"已关闭 MCP 服务 {name}") + finally: + # Remove client from dict after cleanup attempt (successful or not) + self.mcp_client_dict.pop(name, None) + # 移除关联的FuncTool + self.func_list = [ + f + for f in self.func_list + if not (isinstance(f, MCPTool) and f.mcp_server_name == name) + ] + logger.info(f"已关闭 MCP 服务 {name}") @staticmethod async def test_mcp_server_connection(config: dict) -> list[str]: diff --git a/pyproject.toml b/pyproject.toml index 576bc196..70758184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "jieba>=0.42.1", "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "xinference-client", + "tenacity>=9.1.2", ] [dependency-groups] @@ -107,4 +108,4 @@ exclude = ["dashboard", "node_modules", "dist", "data", "tests"] [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" diff --git a/requirements.txt b/requirements.txt index e8b3dee3..b5674119 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,4 @@ rank-bm25>=0.2.2 jieba>=0.42.1 markitdown-no-magika[docx,xls,xlsx]>=0.1.2 xinference-client +tenacity>=9.1.2 \ No newline at end of file