import asyncio import logging from contextlib import AsyncExitStack from datetime import timedelta from typing import Generic from tenacity import ( before_sleep_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe from .run_context import TContext from .tool import FunctionTool try: import anyio import mcp from mcp.client.sse import sse_client except (ModuleNotFoundError, ImportError): logger.warning( "Warning: Missing 'mcp' dependency, MCP services will be unavailable." ) try: from mcp.client.streamable_http import streamablehttp_client except (ModuleNotFoundError, ImportError): logger.warning( "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.", ) def _prepare_config(config: dict) -> dict: """Prepare configuration, handle nested format""" if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] config.pop("active", None) return config async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: """Quick test MCP server connectivity""" import aiohttp cfg = _prepare_config(config.copy()) url = cfg["url"] headers = cfg.get("headers", {}) timeout = cfg.get("timeout", 10) try: if "transport" in cfg: transport_type = cfg["transport"] elif "type" in cfg: transport_type = cfg["type"] else: raise Exception("MCP connection config missing transport or type field") async with aiohttp.ClientSession() as session: if transport_type == "streamable_http": test_payload = { "jsonrpc": "2.0", "method": "initialize", "id": 0, "params": { "protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test-client", "version": "1.2.3"}, }, } async with session.post( url, headers={ **headers, "Content-Type": "application/json", "Accept": "application/json, text/event-stream", }, json=test_payload, timeout=aiohttp.ClientTimeout(total=timeout), ) as response: if response.status == 200: return True, "" return False, f"HTTP {response.status}: {response.reason}" else: async with session.get( url, headers={ **headers, "Accept": "application/json, text/event-stream", }, timeout=aiohttp.ClientTimeout(total=timeout), ) as response: if response.status == 200: return True, "" return False, f"HTTP {response.status}: {response.reason}" except asyncio.TimeoutError: return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" class MCPClient: def __init__(self): # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup self.name: str | None = None self.active: bool = True self.tools: list[mcp.Tool] = [] self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() # Store connection config for reconnection self._mcp_server_config: dict | None = None self._server_name: str | None = None self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging async def connect_to_server(self, mcp_server_config: dict, name: str): """Connect to MCP server If `url` parameter exists: 1. When transport is specified as `streamable_http`, use Streamable HTTP connection. 2. When transport is specified as `sse`, use SSE connection. 3. If not specified, default to SSE connection to MCP service. Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ # Store config for reconnection self._mcp_server_config = mcp_server_config self._server_name = name cfg = _prepare_config(mcp_server_config.copy()) def logging_callback(msg: str): # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) if "url" in cfg: success, error_msg = await _quick_test_mcp_connection(cfg) if not success: raise Exception(error_msg) if "transport" in cfg: transport_type = cfg["transport"] elif "type" in cfg: transport_type = cfg["type"] else: raise Exception("MCP connection config missing transport or type field") if transport_type != "streamable_http": # SSE transport method self._streams_context = sse_client( url=cfg["url"], headers=cfg.get("headers", {}), timeout=cfg.get("timeout", 5), sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), ) streams = await self.exit_stack.enter_async_context( self._streams_context, ) # Create a new client session read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60)) self.session = await self.exit_stack.enter_async_context( mcp.ClientSession( *streams, read_timeout_seconds=read_timeout, logging_callback=logging_callback, # type: ignore ), ) else: timeout = timedelta(seconds=cfg.get("timeout", 30)) sse_read_timeout = timedelta( seconds=cfg.get("sse_read_timeout", 60 * 5), ) self._streams_context = streamablehttp_client( url=cfg["url"], headers=cfg.get("headers", {}), timeout=timeout, sse_read_timeout=sse_read_timeout, terminate_on_close=cfg.get("terminate_on_close", True), ) read_s, write_s, _ = await self.exit_stack.enter_async_context( self._streams_context, ) # Create a new client session read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60)) self.session = await self.exit_stack.enter_async_context( mcp.ClientSession( read_stream=read_s, write_stream=write_s, read_timeout_seconds=read_timeout, logging_callback=logging_callback, # type: ignore ), ) else: server_params = mcp.StdioServerParameters( **cfg, ) def callback(msg: str): # Handle MCP service error logs self.server_errlogs.append(msg) stdio_transport = await self.exit_stack.enter_async_context( mcp.stdio_client( server_params, errlog=LogPipe( level=logging.ERROR, logger=logger, identifier=f"MCPServer-{name}", callback=callback, ), # type: ignore ), ) # Create a new client session self.session = await self.exit_stack.enter_async_context( mcp.ClientSession(*stdio_transport), ) await self.session.initialize() async def list_tools_and_save(self) -> mcp.ListToolsResult: """List all tools from the server and save them to self.tools""" if not self.session: raise Exception("MCP Client is not initialized") response = await self.session.list_tools() self.tools = response.tools return response async def _reconnect(self) -> None: """Reconnect to the MCP server using the stored configuration. Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments. Raises: Exception: raised when reconnection fails """ async with self._reconnect_lock: # Check if already reconnecting (useful for logging) if self._reconnecting: logger.debug( f"MCP Client {self._server_name} is already reconnecting, skipping" ) return if not self._mcp_server_config or not self._server_name: raise Exception("Cannot reconnect: missing connection configuration") self._reconnecting = True try: logger.info( f"Attempting to reconnect to MCP server {self._server_name}..." ) # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues) if self.exit_stack: self._old_exit_stacks.append(self.exit_stack) # Mark old session as invalid self.session = None # Create new exit stack for new connection self.exit_stack = AsyncExitStack() # Reconnect using stored config await self.connect_to_server(self._mcp_server_config, self._server_name) await self.list_tools_and_save() logger.info( f"Successfully reconnected to MCP server {self._server_name}" ) except Exception as e: logger.error( f"Failed to reconnect to MCP server {self._server_name}: {e}" ) raise finally: self._reconnecting = False async def call_tool_with_reconnect( self, tool_name: str, arguments: dict, read_timeout_seconds: timedelta, ) -> mcp.types.CallToolResult: """Call MCP tool with automatic reconnection on failure, max 2 retries. Args: tool_name: tool name arguments: tool arguments read_timeout_seconds: read timeout Returns: MCP tool call result Raises: ValueError: MCP session is not available anyio.ClosedResourceError: raised after reconnection failure """ @retry( retry=retry_if_exception_type(anyio.ClosedResourceError), stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=1, max=3), before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, ) async def _call_with_retry(): if not self.session: raise ValueError("MCP session is not available for MCP function tools.") try: return await self.session.call_tool( name=tool_name, arguments=arguments, read_timeout_seconds=read_timeout_seconds, ) except anyio.ClosedResourceError: logger.warning( f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." ) # Attempt to reconnect await self._reconnect() # Reraise the exception to trigger tenacity retry raise return await _call_with_retry() async def cleanup(self): """Clean up resources including old exit stacks from reconnections""" # Close current exit stack try: await self.exit_stack.aclose() except Exception as e: logger.debug(f"Error closing current exit stack: {e}") # Don't close old exit stacks as they may be in different task contexts # They will be garbage collected naturally # Just clear the list to release references self._old_exit_stacks.clear() # Set running_event first to unblock any waiting tasks self.running_event.set() class MCPTool(FunctionTool, Generic[TContext]): """A function tool that calls an MCP service.""" def __init__( self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs ): super().__init__( name=mcp_tool.name, description=mcp_tool.description or "", parameters=mcp_tool.inputSchema, ) self.mcp_tool = mcp_tool self.mcp_client = mcp_client self.mcp_server_name = mcp_server_name async def call( self, context: ContextWrapper[TContext], **kwargs ) -> mcp.types.CallToolResult: return await self.mcp_client.call_tool_with_reconnect( tool_name=self.mcp_tool.name, arguments=kwargs, read_timeout_seconds=timedelta(seconds=context.tool_call_timeout), )