|
|
|
|
@@ -4,6 +4,14 @@ from contextlib import AsyncExitStack
|
|
|
|
|
from datetime import timedelta
|
|
|
|
|
from typing import Generic
|
|
|
|
|
|
|
|
|
|
from tenacity import (
|
|
|
|
|
before_sleep_log,
|
|
|
|
|
retry,
|
|
|
|
|
retry_if_exception_type,
|
|
|
|
|
stop_after_attempt,
|
|
|
|
|
wait_exponential,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from astrbot import logger
|
|
|
|
|
from astrbot.core.agent.run_context import ContextWrapper
|
|
|
|
|
from astrbot.core.utils.log_pipe import LogPipe
|
|
|
|
|
@@ -12,21 +20,24 @@ from .run_context import TContext
|
|
|
|
|
from .tool import FunctionTool
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
import anyio
|
|
|
|
|
import mcp
|
|
|
|
|
from mcp.client.sse import sse_client
|
|
|
|
|
except (ModuleNotFoundError, ImportError):
|
|
|
|
|
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
|
|
|
except (ModuleNotFoundError, ImportError):
|
|
|
|
|
logger.warning(
|
|
|
|
|
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
|
|
|
|
|
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_config(config: dict) -> dict:
|
|
|
|
|
"""准备配置,处理嵌套格式"""
|
|
|
|
|
"""Prepare configuration, handle nested format"""
|
|
|
|
|
if config.get("mcpServers"):
|
|
|
|
|
first_key = next(iter(config["mcpServers"]))
|
|
|
|
|
config = config["mcpServers"][first_key]
|
|
|
|
|
@@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
|
|
"""快速测试 MCP 服务器可达性"""
|
|
|
|
|
"""Quick test MCP server connectivity"""
|
|
|
|
|
import aiohttp
|
|
|
|
|
|
|
|
|
|
cfg = _prepare_config(config.copy())
|
|
|
|
|
@@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
|
|
elif "type" in cfg:
|
|
|
|
|
transport_type = cfg["type"]
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
|
|
|
|
raise Exception("MCP connection config missing transport or type field")
|
|
|
|
|
|
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
|
|
if transport_type == "streamable_http":
|
|
|
|
|
@@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
|
|
return False, f"HTTP {response.status}: {response.reason}"
|
|
|
|
|
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
return False, f"连接超时: {timeout}秒"
|
|
|
|
|
return False, f"Connection timeout: {timeout} seconds"
|
|
|
|
|
except Exception as e:
|
|
|
|
|
return False, f"{e!s}"
|
|
|
|
|
|
|
|
|
|
@@ -101,6 +112,7 @@ class MCPClient:
|
|
|
|
|
# Initialize session and client objects
|
|
|
|
|
self.session: mcp.ClientSession | None = None
|
|
|
|
|
self.exit_stack = AsyncExitStack()
|
|
|
|
|
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
|
|
|
|
|
|
|
|
|
|
self.name: str | None = None
|
|
|
|
|
self.active: bool = True
|
|
|
|
|
@@ -108,22 +120,32 @@ class MCPClient:
|
|
|
|
|
self.server_errlogs: list[str] = []
|
|
|
|
|
self.running_event = asyncio.Event()
|
|
|
|
|
|
|
|
|
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
|
|
|
|
"""连接到 MCP 服务器
|
|
|
|
|
# Store connection config for reconnection
|
|
|
|
|
self._mcp_server_config: dict | None = None
|
|
|
|
|
self._server_name: str | None = None
|
|
|
|
|
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
|
|
|
|
|
self._reconnecting: bool = False # For logging and debugging
|
|
|
|
|
|
|
|
|
|
如果 `url` 参数存在:
|
|
|
|
|
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
|
|
|
|
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
|
|
|
|
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
|
|
|
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
|
|
|
|
"""Connect to MCP server
|
|
|
|
|
|
|
|
|
|
If `url` parameter exists:
|
|
|
|
|
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
|
|
|
|
|
2. When transport is specified as `sse`, use SSE connection.
|
|
|
|
|
3. If not specified, default to SSE connection to MCP service.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# Store config for reconnection
|
|
|
|
|
self._mcp_server_config = mcp_server_config
|
|
|
|
|
self._server_name = name
|
|
|
|
|
|
|
|
|
|
cfg = _prepare_config(mcp_server_config.copy())
|
|
|
|
|
|
|
|
|
|
def logging_callback(msg: str):
|
|
|
|
|
# 处理 MCP 服务的错误日志
|
|
|
|
|
# Handle MCP service error logs
|
|
|
|
|
print(f"MCP Server {name} Error: {msg}")
|
|
|
|
|
self.server_errlogs.append(msg)
|
|
|
|
|
|
|
|
|
|
@@ -137,7 +159,7 @@ class MCPClient:
|
|
|
|
|
elif "type" in cfg:
|
|
|
|
|
transport_type = cfg["type"]
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
|
|
|
|
raise Exception("MCP connection config missing transport or type field")
|
|
|
|
|
|
|
|
|
|
if transport_type != "streamable_http":
|
|
|
|
|
# SSE transport method
|
|
|
|
|
@@ -193,7 +215,7 @@ class MCPClient:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def callback(msg: str):
|
|
|
|
|
# 处理 MCP 服务的错误日志
|
|
|
|
|
# Handle MCP service error logs
|
|
|
|
|
self.server_errlogs.append(msg)
|
|
|
|
|
|
|
|
|
|
stdio_transport = await self.exit_stack.enter_async_context(
|
|
|
|
|
@@ -222,10 +244,120 @@ class MCPClient:
|
|
|
|
|
self.tools = response.tools
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
async def _reconnect(self) -> None:
|
|
|
|
|
"""Reconnect to the MCP server using the stored configuration.
|
|
|
|
|
|
|
|
|
|
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: raised when reconnection fails
|
|
|
|
|
"""
|
|
|
|
|
async with self._reconnect_lock:
|
|
|
|
|
# Check if already reconnecting (useful for logging)
|
|
|
|
|
if self._reconnecting:
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"MCP Client {self._server_name} is already reconnecting, skipping"
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if not self._mcp_server_config or not self._server_name:
|
|
|
|
|
raise Exception("Cannot reconnect: missing connection configuration")
|
|
|
|
|
|
|
|
|
|
self._reconnecting = True
|
|
|
|
|
try:
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Attempting to reconnect to MCP server {self._server_name}..."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
|
|
|
|
|
if self.exit_stack:
|
|
|
|
|
self._old_exit_stacks.append(self.exit_stack)
|
|
|
|
|
|
|
|
|
|
# Mark old session as invalid
|
|
|
|
|
self.session = None
|
|
|
|
|
|
|
|
|
|
# Create new exit stack for new connection
|
|
|
|
|
self.exit_stack = AsyncExitStack()
|
|
|
|
|
|
|
|
|
|
# Reconnect using stored config
|
|
|
|
|
await self.connect_to_server(self._mcp_server_config, self._server_name)
|
|
|
|
|
await self.list_tools_and_save()
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Successfully reconnected to MCP server {self._server_name}"
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(
|
|
|
|
|
f"Failed to reconnect to MCP server {self._server_name}: {e}"
|
|
|
|
|
)
|
|
|
|
|
raise
|
|
|
|
|
finally:
|
|
|
|
|
self._reconnecting = False
|
|
|
|
|
|
|
|
|
|
async def call_tool_with_reconnect(
|
|
|
|
|
self,
|
|
|
|
|
tool_name: str,
|
|
|
|
|
arguments: dict,
|
|
|
|
|
read_timeout_seconds: timedelta,
|
|
|
|
|
) -> mcp.types.CallToolResult:
|
|
|
|
|
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
tool_name: tool name
|
|
|
|
|
arguments: tool arguments
|
|
|
|
|
read_timeout_seconds: read timeout
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
MCP tool call result
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: MCP session is not available
|
|
|
|
|
anyio.ClosedResourceError: raised after reconnection failure
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@retry(
|
|
|
|
|
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
|
|
|
|
stop=stop_after_attempt(2),
|
|
|
|
|
wait=wait_exponential(multiplier=1, min=1, max=3),
|
|
|
|
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
|
|
|
reraise=True,
|
|
|
|
|
)
|
|
|
|
|
async def _call_with_retry():
|
|
|
|
|
if not self.session:
|
|
|
|
|
raise ValueError("MCP session is not available for MCP function tools.")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
return await self.session.call_tool(
|
|
|
|
|
name=tool_name,
|
|
|
|
|
arguments=arguments,
|
|
|
|
|
read_timeout_seconds=read_timeout_seconds,
|
|
|
|
|
)
|
|
|
|
|
except anyio.ClosedResourceError:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
|
|
|
|
|
)
|
|
|
|
|
# Attempt to reconnect
|
|
|
|
|
await self._reconnect()
|
|
|
|
|
# Reraise the exception to trigger tenacity retry
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
return await _call_with_retry()
|
|
|
|
|
|
|
|
|
|
async def cleanup(self):
|
|
|
|
|
"""Clean up resources"""
|
|
|
|
|
await self.exit_stack.aclose()
|
|
|
|
|
self.running_event.set() # Set the running event to indicate cleanup is done
|
|
|
|
|
"""Clean up resources including old exit stacks from reconnections"""
|
|
|
|
|
# Set running_event first to unblock any waiting tasks
|
|
|
|
|
self.running_event.set()
|
|
|
|
|
|
|
|
|
|
# Close current exit stack
|
|
|
|
|
try:
|
|
|
|
|
await self.exit_stack.aclose()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.debug(f"Error closing current exit stack: {e}")
|
|
|
|
|
|
|
|
|
|
# Don't close old exit stacks as they may be in different task contexts
|
|
|
|
|
# They will be garbage collected naturally
|
|
|
|
|
# Just clear the list to release references
|
|
|
|
|
self._old_exit_stacks.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MCPTool(FunctionTool, Generic[TContext]):
|
|
|
|
|
@@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]):
|
|
|
|
|
async def call(
|
|
|
|
|
self, context: ContextWrapper[TContext], **kwargs
|
|
|
|
|
) -> mcp.types.CallToolResult:
|
|
|
|
|
session = self.mcp_client.session
|
|
|
|
|
if not session:
|
|
|
|
|
raise ValueError("MCP session is not available for MCP function tools.")
|
|
|
|
|
res = await session.call_tool(
|
|
|
|
|
name=self.mcp_tool.name,
|
|
|
|
|
return await self.mcp_client.call_tool_with_reconnect(
|
|
|
|
|
tool_name=self.mcp_tool.name,
|
|
|
|
|
arguments=kwargs,
|
|
|
|
|
read_timeout_seconds=timedelta(
|
|
|
|
|
seconds=context.tool_call_timeout,
|
|
|
|
|
),
|
|
|
|
|
read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
|
|
|
|
|
)
|
|
|
|
|
return res
|
|
|
|
|
|