* refactor: llm message schema * feat: implement MCPTool and local LLM tools with enhanced context handling * refactor: reorganize imports and enhance docstrings for clarity * refactor: enhance ContentPart validation and add message pair handling in ConversationManager * chore: ruff format * refactor: remove debug print statement from payloads in ProviderOpenAIOfficial * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/tool.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/pipeline/process_stage/method/llm_request.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/agent/message.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: enhance documentation and import mcp in tool.py; update call method return type * fix: 修复以数据类的方式注册 tool 时的插件重载机制问题 * refactor: change role attributes to use Literal types for message segments * fix: add support for 'decorator_handler' method in call_local_llm_tool * fix: handle None prompt in text_chat method and ensure context is properly formatted --------- Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
260 lines
9.3 KiB
Python
260 lines
9.3 KiB
Python
import asyncio
|
|
import logging
|
|
from contextlib import AsyncExitStack
|
|
from datetime import timedelta
|
|
from typing import Generic
|
|
|
|
from astrbot import logger
|
|
from astrbot.core.agent.run_context import ContextWrapper
|
|
from astrbot.core.utils.log_pipe import LogPipe
|
|
|
|
from .run_context import TContext
|
|
from .tool import FunctionTool
|
|
|
|
try:
|
|
import mcp
|
|
from mcp.client.sse import sse_client
|
|
except (ModuleNotFoundError, ImportError):
|
|
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
|
|
|
try:
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
except (ModuleNotFoundError, ImportError):
|
|
logger.warning(
|
|
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
|
|
)
|
|
|
|
|
|
def _prepare_config(config: dict) -> dict:
|
|
"""准备配置,处理嵌套格式"""
|
|
if config.get("mcpServers"):
|
|
first_key = next(iter(config["mcpServers"]))
|
|
config = config["mcpServers"][first_key]
|
|
config.pop("active", None)
|
|
return config
|
|
|
|
|
|
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
    """Send one probe request to check whether an HTTP-based MCP server is reachable.

    For the ``streamable_http`` transport, a JSON-RPC ``initialize`` request is
    POSTed. For every other transport value, a plain GET is issued. Only an HTTP
    200 response counts as reachable.

    Args:
        config: Server configuration, flat or nested under ``mcpServers``. It
            must contain ``url`` and should contain ``transport`` or ``type``.
            Optional ``headers`` are merged into the request headers, and
            ``timeout`` (seconds, default 10) bounds the whole request.

    Returns:
        ``(True, "")`` when the server answered with status 200. Otherwise
        ``(False, reason)``, where *reason* describes the HTTP status, the
        timeout, or the exception raised (including a missing transport field).

    Raises:
        KeyError: If the prepared config has no ``url``.
    """
    import aiohttp

    cfg = _prepare_config(config.copy())
    url = cfg["url"]
    extra_headers = cfg.get("headers", {})
    timeout = cfg.get("timeout", 10)

    try:
        # "transport" takes precedence over "type".
        for key in ("transport", "type"):
            if key in cfg:
                transport_type = cfg[key]
                break
        else:
            raise Exception("MCP 连接配置缺少 transport 或 type 字段")

        if transport_type == "streamable_http":
            method = "POST"
            request_options = {
                # Caller headers first, so the protocol headers always win.
                "headers": {
                    **extra_headers,
                    "Content-Type": "application/json",
                    "Accept": "application/json, text/event-stream",
                },
                "json": {
                    "jsonrpc": "2.0",
                    "method": "initialize",
                    "id": 0,
                    "params": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {},
                        "clientInfo": {"name": "test-client", "version": "1.2.3"},
                    },
                },
            }
        else:
            method = "GET"
            request_options = {
                "headers": {
                    **extra_headers,
                    "Accept": "application/json, text/event-stream",
                },
            }

        async with aiohttp.ClientSession() as session:
            async with session.request(
                method,
                url,
                timeout=aiohttp.ClientTimeout(total=timeout),
                **request_options,
            ) as response:
                if response.status != 200:
                    return False, f"HTTP {response.status}: {response.reason}"
                return True, ""
    except asyncio.TimeoutError:
        return False, f"连接超时: {timeout}秒"
    except Exception as e:
        return False, f"{e!s}"
|
|
|
|
|
|
class MCPClient:
    """Client-side connection to a single MCP server.

    Holds the ``mcp.ClientSession`` and every transport context opened for it.
    All of these are registered on ``exit_stack``. The class also keeps the
    tools last listed from the server and the error/log lines the server
    reported.
    """

    def __init__(self):
        # Active session: None until connect_to_server() succeeds, and reset to
        # None by cleanup().
        self.session: mcp.ClientSession | None = None
        # Owns the transport and session contexts so cleanup() closes them together.
        self.exit_stack = AsyncExitStack()

        self.name: str | None = None
        self.active: bool = True
        # Tools from the most recent list_tools_and_save() call.
        self.tools: list[mcp.Tool] = []
        # Log lines reported by the server (stdio stderr or logging notifications).
        self.server_errlogs: list[str] = []
        # Set by cleanup() once resources have been released.
        self.running_event = asyncio.Event()

    async def connect_to_server(self, mcp_server_config: dict, name: str):
        """Connect to an MCP server and initialize the client session.

        When the prepared config contains ``url``:

        1. The server is first probed with ``_quick_test_mcp_connection``. If
           the probe fails, an ``Exception`` carrying its error message is raised.
        2. The transport is read from ``transport``, falling back to ``type``.
           ``"streamable_http"`` selects the Streamable HTTP transport; any other
           value selects SSE. If neither key is present, an ``Exception`` is raised.

        Without ``url``, the config is passed to ``mcp.StdioServerParameters`` and
        the server is launched over stdio. Its stderr is routed to ``logger`` and
        recorded in ``self.server_errlogs``.

        Every opened transport and session is registered on ``self.exit_stack``
        and released by ``cleanup()``.

        Args:
            mcp_server_config (dict): Server configuration, flat or nested under
                ``mcpServers``. See https://modelcontextprotocol.io/quickstart/server
            name (str): Server name, used in log output.

        Raises:
            Exception: If the HTTP probe fails or the transport cannot be determined.
        """
        cfg = _prepare_config(mcp_server_config.copy())

        async def logging_callback(msg) -> None:
            # mcp.ClientSession *awaits* its logging callback, so this must be a
            # coroutine function. With a plain function, the await fails with
            # TypeError as soon as the server sends a log notification. The
            # payload is an mcp notification-params object, not a str, so it is
            # stringified before being stored next to the stdio stderr lines.
            logger.error(f"MCP Server {name} Error: {msg}")
            self.server_errlogs.append(str(msg))

        if "url" in cfg:
            success, error_msg = await _quick_test_mcp_connection(cfg)
            if not success:
                raise Exception(error_msg)

            if "transport" in cfg:
                transport_type = cfg["transport"]
            elif "type" in cfg:
                transport_type = cfg["type"]
            else:
                raise Exception("MCP 连接配置缺少 transport 或 type 字段")

            # Validate the session timeout before any connection is opened.
            read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))

            if transport_type != "streamable_http":
                # SSE transport: the fallback for any transport other than streamable_http.
                self._streams_context = sse_client(
                    url=cfg["url"],
                    headers=cfg.get("headers", {}),
                    timeout=cfg.get("timeout", 5),
                    sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
                )
                read_stream, write_stream = await self.exit_stack.enter_async_context(
                    self._streams_context,
                )
            else:
                self._streams_context = streamablehttp_client(
                    url=cfg["url"],
                    headers=cfg.get("headers", {}),
                    timeout=timedelta(seconds=cfg.get("timeout", 30)),
                    sse_read_timeout=timedelta(
                        seconds=cfg.get("sse_read_timeout", 60 * 5),
                    ),
                    terminate_on_close=cfg.get("terminate_on_close", True),
                )
                # The third yielded item is not needed here.
                read_stream, write_stream, _ = await self.exit_stack.enter_async_context(
                    self._streams_context,
                )

            # Both HTTP transports share the same session setup.
            self.session = await self.exit_stack.enter_async_context(
                mcp.ClientSession(
                    read_stream=read_stream,
                    write_stream=write_stream,
                    read_timeout_seconds=read_timeout,
                    logging_callback=logging_callback,
                ),
            )
        else:
            server_params = mcp.StdioServerParameters(
                **cfg,
            )

            def callback(msg: str):
                # Record stderr output from the stdio server process.
                self.server_errlogs.append(msg)

            stdio_transport = await self.exit_stack.enter_async_context(
                mcp.stdio_client(
                    server_params,
                    errlog=LogPipe(
                        level=logging.ERROR,
                        logger=logger,
                        identifier=f"MCPServer-{name}",
                        callback=callback,
                    ),  # type: ignore
                ),
            )

            # Create a new client session over the stdio streams.
            self.session = await self.exit_stack.enter_async_context(
                mcp.ClientSession(*stdio_transport),
            )

        await self.session.initialize()

    async def list_tools_and_save(self) -> mcp.ListToolsResult:
        """List all tools from the server and store them in ``self.tools``.

        Returns:
            The raw ``ListToolsResult`` returned by the session.

        Raises:
            Exception: If there is no active session.
        """
        if not self.session:
            raise Exception("MCP Client is not initialized")
        response = await self.session.list_tools()
        self.tools = response.tools
        return response

    async def cleanup(self):
        """Close every transport and session opened by ``connect_to_server``.

        The session reference is dropped and ``running_event`` is set even if
        closing raises. Waiters on the event are therefore always released, and
        later tool calls see "no session" instead of using a closed one. Any
        error raised while closing still propagates.
        """
        try:
            await self.exit_stack.aclose()
        finally:
            self.session = None
            self.running_event.set()
|
|
|
|
|
|
class MCPTool(FunctionTool, Generic[TContext]):
    """Function tool backed by a tool exposed on an MCP server.

    The name, description and parameter schema are copied from the
    server-reported ``mcp.Tool``. Invocations are forwarded through the
    session of the associated ``MCPClient``.
    """

    def __init__(
        self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
    ):
        """Wrap *mcp_tool* as a FunctionTool.

        Args:
            mcp_tool: Tool definition reported by the MCP server.
            mcp_client: Client whose session executes the tool.
            mcp_server_name: Name of the server that owns the tool.
            **kwargs: Accepted but not used.
        """
        tool_name = mcp_tool.name
        summary = mcp_tool.description or ""
        schema = mcp_tool.inputSchema
        super().__init__(name=tool_name, description=summary, parameters=schema)
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client
        self.mcp_server_name = mcp_server_name

    async def call(
        self, context: ContextWrapper[TContext], **kwargs
    ) -> mcp.types.CallToolResult:
        """Execute the remote tool through the client's MCP session.

        Args:
            context: Run context. Its ``tool_call_timeout`` (in seconds) is
                passed as the call's read timeout.
            **kwargs: Tool arguments, sent unchanged as the call's ``arguments``.

        Returns:
            The ``CallToolResult`` returned by the session.

        Raises:
            ValueError: If the client currently has no session.
        """
        active_session = self.mcp_client.session
        if not active_session:
            raise ValueError("MCP session is not available for MCP function tools.")
        reply_timeout = timedelta(seconds=context.tool_call_timeout)
        return await active_session.call_tool(
            name=self.mcp_tool.name,
            arguments=kwargs,
            read_timeout_seconds=reply_timeout,
        )
|