From 17d62a9af7872d469cf3e28897701d90ca9b9945 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sun, 20 Jul 2025 15:53:13 +0800
Subject: [PATCH 1/7] refactor: mcp server reload mechanism (#2161)
* refactor: mcp server reload mechanism
* fix: wait for client events
* fix: all other mcp servers are terminated when disable selected server
* fix: resolve type hinting issues in MCPClient and FuncCall methods
* perf: optimize mcp server loaders
* perf: improve MCP client connection testing
* perf: improve error message
* perf: clean code
* perf: increase default timeout for MCP connection and reset dialog message on close
---------
Co-authored-by: Raven95676
---
astrbot/core/provider/func_tool_manager.py | 359 ++++++++++++------
astrbot/core/provider/manager.py | 15 +-
astrbot/dashboard/routes/tools.py | 123 ++++--
dashboard/src/components/shared/ItemCard.vue | 6 +
.../i18n/locales/en-US/features/tool-use.json | 11 +-
.../i18n/locales/zh-CN/features/tool-use.json | 11 +-
dashboard/src/views/ToolUsePage.vue | 223 +++++++----
7 files changed, 509 insertions(+), 239 deletions(-)
diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py
index cea2e4f3..07a0fbd8 100644
--- a/astrbot/core/provider/func_tool_manager.py
+++ b/astrbot/core/provider/func_tool_manager.py
@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
+def _prepare_config(config: dict) -> dict:
+ """准备配置,处理嵌套格式"""
+ if "mcpServers" in config and config["mcpServers"]:
+ first_key = next(iter(config["mcpServers"]))
+ config = config["mcpServers"][first_key]
+ config.pop("active", None)
+ return config
+
+
+async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
+ """快速测试 MCP 服务器可达性"""
+ import aiohttp
+
+ cfg = _prepare_config(config.copy())
+
+ url = cfg["url"]
+ headers = cfg.get("headers", {})
+ timeout = cfg.get("timeout", 10)
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ if cfg.get("transport") == "streamable_http":
+ test_payload = {
+ "jsonrpc": "2.0",
+ "method": "initialize",
+ "id": 0,
+ "params": {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {},
+ "clientInfo": {"name": "test-client", "version": "1.2.3"},
+ },
+ }
+ async with session.post(
+ url,
+ headers={
+ **headers,
+ "Content-Type": "application/json",
+ "Accept": "application/json, text/event-stream",
+ },
+ json=test_payload,
+ timeout=aiohttp.ClientTimeout(total=timeout),
+ ) as response:
+ if response.status == 200:
+ return True, ""
+ else:
+ return False, f"HTTP {response.status}: {response.reason}"
+ else:
+ async with session.get(
+ url,
+ headers={
+ **headers,
+ "Accept": "application/json, text/event-stream",
+ },
+ timeout=aiohttp.ClientTimeout(total=timeout),
+ ) as response:
+ if response.status == 200:
+ return True, ""
+ else:
+ return False, f"HTTP {response.status}: {response.reason}"
+
+ except asyncio.TimeoutError:
+ return False, f"连接超时: {timeout}秒"
+ except Exception as e:
+ return False, f"{e!s}"
+
+
@dataclass
class FuncTool:
"""
@@ -80,12 +146,10 @@ class FuncTool:
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
- if ":" in self.name:
- # 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
- actual_tool_name = self.name.split(":")[-1]
- return await self.mcp_client.session.call_tool(actual_tool_name, args)
- else:
- return await self.mcp_client.session.call_tool(self.name, args)
+ actual_tool_name = (
+ self.name.split(":")[-1] if ":" in self.name else self.name
+ )
+ return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
@@ -100,6 +164,7 @@ class MCPClient:
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
+ self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
@@ -112,17 +177,19 @@ class MCPClient:
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
- cfg = mcp_server_config.copy()
- if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
- key_0 = list(cfg["mcpServers"].keys())[0]
- cfg = cfg["mcpServers"][key_0]
- cfg.pop("active", None) # Remove active flag from config
+ cfg = _prepare_config(mcp_server_config.copy())
+
+ def logging_callback(msg: str):
+ # 处理 MCP 服务的错误日志
+ print(f"MCP Server {name} Error: {msg}")
+ self.server_errlogs.append(msg)
if "url" in cfg:
- is_sse = True
- if cfg.get("transport") == "streamable_http":
- is_sse = False
- if is_sse:
+ success, error_msg = await _quick_test_mcp_connection(cfg)
+ if not success:
+ raise Exception(error_msg)
+
+ if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
@@ -130,11 +197,18 @@ class MCPClient:
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
- streams = await self.exit_stack.enter_async_context(self._streams_context)
+ streams = await self.exit_stack.enter_async_context(
+ self._streams_context
+ )
# Create a new client session
+ read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
- mcp.ClientSession(*streams)
+ mcp.ClientSession(
+ *streams,
+ read_timeout_seconds=read_timeout,
+ logging_callback=logging_callback, # type: ignore
+ )
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
@@ -148,11 +222,19 @@ class MCPClient:
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
- read_s, write_s, _ = await self.exit_stack.enter_async_context(self._streams_context)
+ read_s, write_s, _ = await self.exit_stack.enter_async_context(
+ self._streams_context
+ )
# Create a new client session
+ read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
- mcp.ClientSession(read_stream=read_s, write_stream=write_s)
+ mcp.ClientSession(
+ read_stream=read_s,
+ write_stream=write_s,
+ read_timeout_seconds=read_timeout,
+ logging_callback=logging_callback, # type: ignore
+ )
)
else:
@@ -172,7 +254,7 @@ class MCPClient:
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
- ),
+ ), # type: ignore
),
)
@@ -180,19 +262,18 @@ class MCPClient:
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
-
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
- logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
+ self.running_event.set() # Set the running event to indicate cleanup is done
class FuncCall:
@@ -201,8 +282,6 @@ class FuncCall:
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
- self.mcp_service_queue = asyncio.Queue()
- """用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool:
@@ -258,7 +337,7 @@ class FuncCall:
return f
return None
- async def _init_mcp_clients(self) -> None:
+ async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
@@ -300,115 +379,64 @@ class FuncCall:
)
self.mcp_client_event[name] = event
- async def mcp_service_selector(self):
- """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
-
- 使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
-
- {"type": "init"} 初始化所有MCP客户端
-
- {"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
-
- {"type": "terminate"} 终止所有MCP客户端
-
- {"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
- """
- while True:
- data = await self.mcp_service_queue.get()
- if data["type"] == "init":
- if "name" in data:
- event = asyncio.Event()
- asyncio.create_task(
- self._init_mcp_client_task_wrapper(
- data["name"], data["cfg"], event
- )
- )
- self.mcp_client_event[data["name"]] = event
- else:
- await self._init_mcp_clients()
- elif data["type"] == "terminate":
- if "name" in data:
- # await self._terminate_mcp_client(data["name"])
- if data["name"] in self.mcp_client_event:
- self.mcp_client_event[data["name"]].set()
- self.mcp_client_event.pop(data["name"], None)
- self.func_list = [
- f
- for f in self.func_list
- if not (
- f.origin == "mcp" and f.mcp_server_name == data["name"]
- )
- ]
- else:
- for name in self.mcp_client_dict.keys():
- # await self._terminate_mcp_client(name)
- # self.mcp_client_event[name].set()
- if name in self.mcp_client_event:
- self.mcp_client_event[name].set()
- self.mcp_client_event.pop(name, None)
- self.func_list = [f for f in self.func_list if f.origin != "mcp"]
-
async def _init_mcp_client_task_wrapper(
- self, name: str, cfg: dict, event: asyncio.Event
+ self,
+ name: str,
+ cfg: dict,
+ event: asyncio.Event,
+ ready_future: asyncio.Future = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
+ tools = await self.mcp_client_dict[name].list_tools_and_save()
+ if ready_future and not ready_future.done():
+ # tell the caller we are ready
+ ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception as e:
- import traceback
-
- traceback.print_exc()
- logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
+ logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
+ if ready_future and not ready_future.done():
+ ready_future.set_exception(e)
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
- try:
- # 先清理之前的客户端,如果存在
- if name in self.mcp_client_dict:
- await self._terminate_mcp_client(name)
+ # 先清理之前的客户端,如果存在
+ if name in self.mcp_client_dict:
+ await self._terminate_mcp_client(name)
- mcp_client = MCPClient()
- mcp_client.name = name
- self.mcp_client_dict[name] = mcp_client
- await mcp_client.connect_to_server(config, name)
- tools_res = await mcp_client.list_tools_and_save()
- tool_names = [tool.name for tool in tools_res.tools]
+ mcp_client = MCPClient()
+ mcp_client.name = name
+ self.mcp_client_dict[name] = mcp_client
+ await mcp_client.connect_to_server(config, name)
+ tools_res = await mcp_client.list_tools_and_save()
+ logger.debug(f"MCP server {name} list tools response: {tools_res}")
+ tool_names = [tool.name for tool in tools_res.tools]
- # 移除该MCP服务之前的工具(如有)
- self.func_list = [
- f
- for f in self.func_list
- if not (f.origin == "mcp" and f.mcp_server_name == name)
- ]
+ # 移除该MCP服务之前的工具(如有)
+ self.func_list = [
+ f
+ for f in self.func_list
+ if not (f.origin == "mcp" and f.mcp_server_name == name)
+ ]
- # 将 MCP 工具转换为 FuncTool 并添加到 func_list
- for tool in mcp_client.tools:
- func_tool = FuncTool(
- name=tool.name,
- parameters=tool.inputSchema,
- description=tool.description,
- origin="mcp",
- mcp_server_name=name,
- mcp_client=mcp_client,
- )
- self.func_list.append(func_tool)
+ # 将 MCP 工具转换为 FuncTool 并添加到 func_list
+ for tool in mcp_client.tools:
+ func_tool = FuncTool(
+ name=tool.name,
+ parameters=tool.inputSchema,
+ description=tool.description,
+ origin="mcp",
+ mcp_server_name=name,
+ mcp_client=mcp_client,
+ )
+ self.func_list.append(func_tool)
- logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
- return
- except Exception as e:
- import traceback
-
- logger.error(traceback.format_exc())
- logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
- # 发生错误时确保客户端被清理
- if name in self.mcp_client_dict:
- await self._terminate_mcp_client(name)
- return
+ logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
@@ -418,7 +446,7 @@ class FuncCall:
await self.mcp_client_dict[name].cleanup()
self.mcp_client_dict.pop(name)
except Exception as e:
- logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
+ logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
# 移除关联的FuncTool
self.func_list = [
f
@@ -427,6 +455,103 @@ class FuncCall:
]
logger.info(f"已关闭 MCP 服务 {name}")
+ @staticmethod
+ async def test_mcp_server_connection(config: dict) -> list[str]:
+ if "url" in config:
+ success, error_msg = await _quick_test_mcp_connection(config)
+ if not success:
+ raise Exception(error_msg)
+
+ mcp_client = MCPClient()
+ try:
+ logger.debug(f"testing MCP server connection with config: {config}")
+ await mcp_client.connect_to_server(config, "test")
+ tools_res = await mcp_client.list_tools_and_save()
+ tool_names = [tool.name for tool in tools_res.tools]
+ finally:
+ logger.debug("Cleaning up MCP client after testing connection.")
+ await mcp_client.cleanup()
+ return tool_names
+
+ async def enable_mcp_server(
+ self,
+ name: str,
+ config: dict,
+ event: asyncio.Event | None = None,
+ ready_future: asyncio.Future | None = None,
+ timeout: int = 30,
+ ) -> None:
+ """Enable_mcp_server a new MCP server to the manager and initialize it.
+
+ Args:
+ name (str): The name of the MCP server.
+ config (dict): Configuration for the MCP server.
+ event (asyncio.Event): Event to signal when the MCP client is ready.
+ ready_future (asyncio.Future): Future to signal when the MCP client is ready.
+ timeout (int): Timeout for the initialization.
+ Raises:
+ TimeoutError: If the initialization does not complete within the specified timeout.
+ Exception: If there is an error during initialization.
+ """
+ if not event:
+ event = asyncio.Event()
+ if not ready_future:
+ ready_future = asyncio.Future()
+ if name in self.mcp_client_dict:
+ return
+ asyncio.create_task(
+ self._init_mcp_client_task_wrapper(name, config, event, ready_future)
+ )
+ try:
+ await asyncio.wait_for(ready_future, timeout=timeout)
+ finally:
+ self.mcp_client_event[name] = event
+
+ if ready_future.done() and ready_future.exception():
+ exc = ready_future.exception()
+ if exc is not None:
+ raise exc
+
+ async def disable_mcp_server(
+ self, name: str | None = None, timeout: float = 10
+ ) -> None:
+ """Disable an MCP server by its name.
+
+ Args:
+ name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
+ timeout (int): Timeout.
+ """
+ if name:
+ if name not in self.mcp_client_event:
+ return
+ client = self.mcp_client_dict.get(name)
+ self.mcp_client_event[name].set()
+ if not client:
+ return
+ client_running_event = client.running_event
+ try:
+ await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
+ finally:
+ self.mcp_client_event.pop(name, None)
+ self.func_list = [
+ f
+ for f in self.func_list
+ if f.origin != "mcp" or f.mcp_server_name != name
+ ]
+ else:
+ running_events = [
+ client.running_event.wait() for client in self.mcp_client_dict.values()
+ ]
+ for key, event in self.mcp_client_event.items():
+ event.set()
+ # waiting for all clients to finish
+ try:
+ await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
+ finally:
+ self.mcp_client_event.clear()
+ self.mcp_client_dict.clear()
+ self.func_list = [f for f in self.func_list if f.origin != "mcp"]
+
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index df21e6a1..370c5322 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -169,10 +169,7 @@ class ProviderManager:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
- asyncio.create_task(
- self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
- )
- self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
+ asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
@@ -422,7 +419,7 @@ class ProviderManager:
self.curr_tts_provider_inst = None
if getattr(self.inst_map[provider_id], "terminate", None):
- await self.inst_map[provider_id].terminate() # type: ignore
+ await self.inst_map[provider_id].terminate() # type: ignore
logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
@@ -432,6 +429,8 @@ class ProviderManager:
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
- await provider_inst.terminate() # type: ignore
- # 清理 MCP Client 连接
- await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
+ await provider_inst.terminate() # type: ignore
+ try:
+ await self.llm_tools.disable_mcp_server()
+ except Exception:
+ logger.error("Error while disabling MCP servers", exc_info=True)
diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py
index d38014c7..5dad2576 100644
--- a/astrbot/dashboard/routes/tools.py
+++ b/astrbot/dashboard/routes/tools.py
@@ -26,6 +26,7 @@ class ToolsRoute(Route):
"/tools/mcp/update": ("POST", self.update_mcp_server),
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
"/tools/mcp/market": ("GET", self.get_mcp_markets),
+ "/tools/mcp/test": ("POST", self.test_mcp_connection),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
@@ -132,12 +133,19 @@ class ToolsRoute(Route):
config["mcpServers"][name] = server_config
if self.save_mcp_config(config):
- # 动态初始化新MCP客户端
- await self.tool_mgr.mcp_service_queue.put({
- "type": "init",
- "name": name,
- "cfg": config["mcpServers"][name],
- })
+ try:
+ await self.tool_mgr.enable_mcp_server(
+ name, server_config, timeout=30
+ )
+ except TimeoutError:
+ return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
@@ -193,31 +201,55 @@ class ToolsRoute(Route):
if self.save_mcp_config(config):
# 处理MCP客户端状态变化
if active:
- # 如果要激活服务器或者配置已更改
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
- await self.tool_mgr.mcp_service_queue.put({
- "type": "terminate",
- "name": name,
- })
- await self.tool_mgr.mcp_service_queue.put({
- "type": "init",
- "name": name,
- "cfg": config["mcpServers"][name],
- })
- else:
- # 客户端不存在,初始化
- await self.tool_mgr.mcp_service_queue.put({
- "type": "init",
- "name": name,
- "cfg": config["mcpServers"][name],
- })
+ try:
+ await self.tool_mgr.disable_mcp_server(name, timeout=10)
+ except TimeoutError as e:
+ return (
+ Response()
+ .error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}")
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}")
+ .__dict__
+ )
+ try:
+ await self.tool_mgr.enable_mcp_server(
+ name, config["mcpServers"][name], timeout=30
+ )
+ except TimeoutError:
+ return (
+ Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
else:
# 如果要停用服务器
if name in self.tool_mgr.mcp_client_dict:
- self.tool_mgr.mcp_service_queue.put_nowait({
- "type": "terminate",
- "name": name,
- })
+ try:
+ await self.tool_mgr.disable_mcp_server(name, timeout=10)
+ except TimeoutError:
+ return (
+ Response()
+ .error(f"停用 MCP 服务器 {name} 超时。")
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
else:
@@ -239,17 +271,23 @@ class ToolsRoute(Route):
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
- # 删除服务器配置
del config["mcpServers"][name]
if self.save_mcp_config(config):
- # 关闭并删除MCP客户端
if name in self.tool_mgr.mcp_client_dict:
- self.tool_mgr.mcp_service_queue.put_nowait({
- "type": "terminate",
- "name": name,
- })
-
+ try:
+ await self.tool_mgr.disable_mcp_server(name, timeout=10)
+ except TimeoutError:
+ return (
+ Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
@@ -281,3 +319,20 @@ class ToolsRoute(Route):
except Exception as _:
logger.error(traceback.format_exc())
return Response().error("获取市场数据失败").__dict__
+
+ async def test_mcp_connection(self):
+ """
+ 测试 MCP 服务器连接
+ """
+ try:
+ server_data = await request.json
+ config = server_data.get("mcp_server_config", None)
+
+ tools_name = await self.tool_mgr.test_mcp_server_connection(config)
+ return (
+ Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
+ )
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
diff --git a/dashboard/src/components/shared/ItemCard.vue b/dashboard/src/components/shared/ItemCard.vue
index ff790cb7..6152c531 100644
--- a/dashboard/src/components/shared/ItemCard.vue
+++ b/dashboard/src/components/shared/ItemCard.vue
@@ -9,6 +9,8 @@
hide-details
density="compact"
:model-value="getItemEnabled()"
+ :loading="loading"
+ :disabled="loading"
v-bind="props"
@update:model-value="toggleEnabled"
>
@@ -77,6 +79,10 @@ export default {
bglogo: {
type: String,
default: null
+ },
+ loading: {
+ type: Boolean,
+ default: false
}
},
emits: ['toggle-enabled', 'delete', 'edit'],
diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json
index fad67a0d..bd36fd68 100644
--- a/dashboard/src/i18n/locales/en-US/features/tool-use.json
+++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json
@@ -15,7 +15,9 @@
"buttons": {
"refresh": "Refresh",
"add": "Add Server",
- "useTemplate": "Use Template"
+ "useTemplateStdio": "Stdio Template",
+ "useTemplateStreamableHttp": "Streamable HTTP Template",
+ "useTemplateSse": "SSE Template"
},
"empty": "No MCP servers available, click Add Server to add one",
"status": {
@@ -68,10 +70,6 @@
"enable": "Enable Server",
"config": "Server Configuration"
},
- "configNotes": {
- "note1": "1. Some MCP servers may require filling in `API_KEY` or `TOKEN` information in env according to their requirements, please check if filled.",
- "note2": "2. When url parameter is specified in configuration: if `transport` parameter is also specified as `streamable_http`, Streamable HTTP is used, otherwise SSE connection is used."
- },
"errors": {
"configEmpty": "Configuration cannot be empty",
"jsonFormat": "JSON format error: {error}",
@@ -79,7 +77,8 @@
},
"buttons": {
"cancel": "Cancel",
- "save": "Save"
+ "save": "Save",
+ "testConnection": "Test Connection"
}
},
"serverDetail": {
diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
index f44a16d5..c9e8e858 100644
--- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
+++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
@@ -15,7 +15,9 @@
"buttons": {
"refresh": "刷新",
"add": "新增服务器",
- "useTemplate": "使用模板"
+ "useTemplateStdio": "Stdio 模板",
+ "useTemplateStreamableHttp": "Streamable HTTP 模板",
+ "useTemplateSse": "SSE 模板"
},
"empty": "暂无 MCP 服务器,点击 新增服务器 添加",
"status": {
@@ -68,10 +70,6 @@
"enable": "启用服务器",
"config": "服务器配置"
},
- "configNotes": {
- "note1": "1. 某些 MCP 服务器可能需要按照其要求在 env 中填充 `API_KEY` 或 `TOKEN` 等信息,请注意检查是否填写。",
- "note2": "2. 当配置中指定 url 参数时:如果还同时指定 `transport` 参数的值为 `streamable_http`,则使用 Steamable HTTP,否则使用 SSE 连接。"
- },
"errors": {
"configEmpty": "配置不能为空",
"jsonFormat": "JSON 格式错误: {error}",
@@ -79,7 +77,8 @@
},
"buttons": {
"cancel": "取消",
- "save": "保存"
+ "save": "保存",
+ "testConnection": "测试连接"
}
},
"serverDetail": {
diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue
index cc70c415..93f53924 100644
--- a/dashboard/src/views/ToolUsePage.vue
+++ b/dashboard/src/views/ToolUsePage.vue
@@ -20,7 +20,8 @@
-
+
{{ tm('mcpServers.buttons.add') }}
@@ -49,7 +50,8 @@
mdi-server
{{ tm('mcpServers.title') }}
-
+
{{ tm('mcpServers.buttons.refresh') }}
-
+
mdi-file-code
-
+
{{ getServerConfigSummary(item) }}
-
-
-
mdi-tools
-
{{ tm('mcpServers.status.availableTools') }} ({{ item.tools.length }})
+
+
+
+
+
+
mdi-tools
+
+
+
+ {{ tm('mcpServers.status.availableTools', { count: item.tools.length }) }} ({{
+ item.tools.length }})
+
+
+
+
+
+ {{ tm('mcpServers.status.availableTools') }}
+
+
+
+
+
+
+ Close
+
+
+
+
+
+
+
+
+
+
+ mdi-alert-circle
+ {{ tm('mcpServers.status.noTools') }}
+
+
+
+
-
-
- {{ tool }}
-
-
-
-
- mdi-alert-circle
- {{ tm('mcpServers.status.noTools') }}
+
+
@@ -131,8 +162,9 @@
-
+
mdi-store
{{ tm('marketplace.title') }}
-
+
{{ tm('marketplace.buttons.refresh') }}
@@ -256,7 +288,8 @@
mdi-tools
- {{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 }) }}
+ {{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 })
+ }}
@@ -310,31 +343,25 @@
-
-
-
+
{{ tm('dialogs.addServer.fields.config') }}
-
-
- mdi-information
-
-
- {{ tm('tooltip.serverConfig') }}
-
-
-
- {{ tm('mcpServers.buttons.useTemplate') }}
+
+ {{ tm('mcpServers.buttons.useTemplateStdio') }}
+
+
+ {{ tm('mcpServers.buttons.useTemplateStreamableHttp') }}
+
+
+ {{ tm('mcpServers.buttons.useTemplateSse') }}
- {{ tm('dialogs.addServer.configNotes.note1') }}
-
- {{ tm('dialogs.addServer.configNotes.note2') }}
-
+
+ {{ addServerDialogMessage }}
+
-
+
{{ tm('dialogs.addServer.buttons.cancel') }}
+
+ {{ tm('dialogs.addServer.buttons.testConnection') }}
+
{{ tm('dialogs.addServer.buttons.save') }}
@@ -504,8 +536,11 @@ export default {
tools: [],
showMcpServerDialog: false,
showServerDetailDialog: false,
+ addServerDialogMessage: "",
showTools: true,
loading: false,
+ loadingGettingServers: false,
+ mcpServerUpdateLoaders: {}, // record loading state for each server update
isEditMode: false,
serverConfigJson: '',
jsonError: null,
@@ -575,10 +610,10 @@ export default {
if (!this.marketplaceSearch.trim()) {
return this.marketplaceServers;
}
-
+
const searchTerm = this.marketplaceSearch.toLowerCase();
- return this.marketplaceServers.filter(server =>
- server.name.toLowerCase().includes(searchTerm) ||
+ return this.marketplaceServers.filter(server =>
+ server.name.toLowerCase().includes(searchTerm) ||
(server.name_h && server.name_h.toLowerCase().includes(searchTerm)) ||
(server.description && server.description.toLowerCase().includes(searchTerm))
);
@@ -618,17 +653,21 @@ export default {
},
getServers() {
- this.loading = true
+ this.loadingGettingServers = true;
axios.get('/api/tools/mcp/servers')
.then(response => {
this.mcpServers = response.data.data || [];
+ this.mcpServers.forEach(server => {
+ // Ensure each server has a loader state
+ if (!this.mcpServerUpdateLoaders[server.name]) {
+ this.mcpServerUpdateLoaders[server.name] = false;
+ }
+ });
})
.catch(error => {
this.showError(this.tm('messages.getServersError', { error: error.message }));
}).finally(() => {
- setTimeout(() => {
- this.loading = false;
- }, 500);
+ this.loadingGettingServers = false;
});
},
@@ -658,14 +697,28 @@ export default {
}
},
- setConfigTemplate() {
- // 设置一个基本的配置模板
- const template = {
- command: "python",
- args: ["-m", "your_module"],
- // 可以添加其他 MCP 支持的配置项
- };
-
+ setConfigTemplate(type = 'stdio') {
+ let template = {};
+ if (type === 'streamable_http') {
+ template = {
+ transport: "streamable_http",
+ url: "your mcp server url",
+ headers: {},
+ timeout: 30,
+ };
+ } else if (type === 'sse') {
+ template = {
+ transport: "sse",
+ url: "your mcp server url",
+ headers: {},
+ timeout: 30,
+ };
+ } else {
+ template = {
+ command: "python",
+ args: ["-m", "your_module"],
+ };
+ }
this.serverConfigJson = JSON.stringify(template, null, 2);
},
@@ -693,6 +746,7 @@ export default {
.then(response => {
this.loading = false;
this.showMcpServerDialog = false;
+ this.addServerDialogMessage = "";
this.getServers();
this.getTools();
this.showSuccess(response.data.message || this.tm('messages.saveSuccess'));
@@ -753,6 +807,7 @@ export default {
updateServerStatus(server) {
// 切换服务器状态
+ this.mcpServerUpdateLoaders[server.name] = true;
server.active = !server.active;
axios.post('/api/tools/mcp/update', server)
.then(response => {
@@ -761,16 +816,48 @@ export default {
})
.catch(error => {
this.showError(this.tm('messages.updateError', { error: error.response?.data?.message || error.message }));
- // 回滚状态
server.active = !server.active;
+ })
+ .finally(() => {
+ this.mcpServerUpdateLoaders[server.name] = false;
});
},
closeServerDialog() {
this.showMcpServerDialog = false;
+ this.addServerDialogMessage = '';
this.resetForm();
},
+ testServerConnection() {
+ if (!this.validateJson()) {
+ return;
+ }
+
+ this.loading = true;
+
+ let configObj;
+ try {
+ configObj = JSON.parse(this.serverConfigJson);
+ } catch (e) {
+ this.loading = false;
+ this.showError(this.tm('dialogs.addServer.errors.jsonParse', { error: e.message }));
+ return;
+ }
+
+ axios.post('/api/tools/mcp/test', {
+ "mcp_server_config": configObj,
+ })
+ .then(response => {
+ this.loading = false;
+ this.addServerDialogMessage = `${response.data.message} (tools: ${response.data.data})`;
+ })
+ .catch(error => {
+ this.loading = false;
+ this.showError(this.tm('messages.testError', { error: error.response?.data?.message || error.message }));
+ });
+ },
+
resetForm() {
this.currentServer = {
name: '',
@@ -939,7 +1026,7 @@ export default {
.monaco-container {
border: 1px solid rgba(0, 0, 0, 0.1);
- border-radius: 4px;
+ border-radius: 8px;
height: 300px;
margin-top: 4px;
overflow: hidden;
From b5d8173ee3798d638514ba45c2fabeb3733bf876 Mon Sep 17 00:00:00 2001
From: RC-CHN <67079377+RC-CHN@users.noreply.github.com>
Date: Sun, 20 Jul 2025 16:02:28 +0800
Subject: [PATCH 2/7] feat: add a file uplod button in WebChat page (#2136)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)
* fix:上传后清空value,允许触发change事件以多次上传同一张图片
* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致
* perf:将文件输入的值重置为空字符串以提升浏览器兼容性
* feat:webchat文件上传按钮支持多选文件上传
* fix:释放blob URL以防止内存泄漏
* perf:并行化sendMessage中的图片获取逻辑
---
dashboard/src/views/ChatPage.vue | 113 ++++++++++++++++++++-----------
1 file changed, 72 insertions(+), 41 deletions(-)
diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue
index 099cf09c..ad75cbd3 100644
--- a/dashboard/src/views/ChatPage.vue
+++ b/dashboard/src/views/ChatPage.vue
@@ -226,6 +226,9 @@
+
+
@@ -668,34 +671,44 @@ export default {
};
},
+ async processAndUploadImage(file) {
+ const formData = new FormData();
+ formData.append('file', file);
+
+ try {
+ const response = await axios.post('/api/chat/post_image', formData, {
+ headers: {
+ 'Content-Type': 'multipart/form-data'
+ }
+ });
+
+ const img = response.data.data.filename;
+ this.stagedImagesName.push(img); // Store just the filename
+ this.stagedImagesUrl.push(URL.createObjectURL(file)); // Create a blob URL for immediate display
+
+ } catch (err) {
+ console.error('Error uploading image:', err);
+ }
+ },
+
async handlePaste(event) {
console.log('Pasting image...');
const items = event.clipboardData.items;
for (let i = 0; i < items.length; i++) {
if (items[i].type.indexOf('image') !== -1) {
const file = items[i].getAsFile();
- const formData = new FormData();
- formData.append('file', file);
-
- try {
- const response = await axios.post('/api/chat/post_image', formData, {
- headers: {
- 'Content-Type': 'multipart/form-data'
- }
- });
-
- const img = response.data.data.filename;
- this.stagedImagesName.push(img); // Store just the filename
- this.stagedImagesUrl.push(URL.createObjectURL(file)); // Create a blob URL for immediate display
-
- } catch (err) {
- console.error('Error uploading image:', err);
- }
+ this.processAndUploadImage(file);
}
}
},
removeImage(index) {
+ // Revoke the blob URL to prevent memory leaks
+ const urlToRevoke = this.stagedImagesUrl[index];
+ if (urlToRevoke && urlToRevoke.startsWith('blob:')) {
+ URL.revokeObjectURL(urlToRevoke);
+ }
+
this.stagedImagesName.splice(index, 1);
this.stagedImagesUrl.splice(index, 1);
},
@@ -703,6 +716,21 @@ export default {
clearMessage() {
this.prompt = '';
},
+
+ triggerImageInput() {
+ this.$refs.imageInput.click();
+ },
+
+ handleFileSelect(event) {
+ const files = event.target.files;
+ if (files) {
+ for (const file of files) {
+ this.processAndUploadImage(file);
+ }
+ }
+ // Reset the input value to allow selecting the same file again
+ event.target.value = '';
+ },
getConversations() {
axios.get('/api/chat/conversations').then(response => {
this.conversations = response.data.data;
@@ -846,33 +874,42 @@ export default {
// URL is already updated in newConversation method
}
+ // 保存当前要发送的数据到临时变量
+ const promptToSend = this.prompt.trim();
+ const imageNamesToSend = [...this.stagedImagesName];
+ const audioNameToSend = this.stagedAudioUrl;
+
+ // 立即清空输入和附件预览
+ this.prompt = '';
+ this.stagedImagesName = [];
+ this.stagedImagesUrl = [];
+ this.stagedAudioUrl = "";
+
// Create a message object with actual URLs for display
const userMessage = {
type: 'user',
- message: this.prompt.trim(), // 使用 trim() 去除前后空格
+ message: promptToSend,
image_url: [],
audio_url: null
};
// Convert image filenames to blob URLs for display
- if (this.stagedImagesName.length > 0) {
- for (let i = 0; i < this.stagedImagesName.length; i++) {
- // If it's just a filename, get the blob URL
- if (!this.stagedImagesName[i].startsWith('blob:')) {
- const imgUrl = await this.getMediaFile(this.stagedImagesName[i]);
- userMessage.image_url.push(imgUrl);
- } else {
- userMessage.image_url.push(this.stagedImagesName[i]);
+ if (imageNamesToSend.length > 0) {
+ const imagePromises = imageNamesToSend.map(name => {
+ if (!name.startsWith('blob:')) {
+ return this.getMediaFile(name);
}
- }
+ return Promise.resolve(name);
+ });
+ userMessage.image_url = await Promise.all(imagePromises);
}
// Convert audio filename to blob URL for display
- if (this.stagedAudioUrl) {
- if (!this.stagedAudioUrl.startsWith('blob:')) {
- userMessage.audio_url = await this.getMediaFile(this.stagedAudioUrl);
+ if (audioNameToSend) {
+ if (!audioNameToSend.startsWith('blob:')) {
+ userMessage.audio_url = await this.getMediaFile(audioNameToSend);
} else {
- userMessage.audio_url = this.stagedAudioUrl;
+ userMessage.audio_url = audioNameToSend;
}
}
@@ -885,8 +922,6 @@ export default {
const selection = this.$refs.providerModelSelector?.getCurrentSelection();
const selectedProviderId = selection?.providerId || '';
const selectedModelName = selection?.modelName || '';
- let prompt = this.prompt.trim();
- this.prompt = ''; // 清空输入框
try {
const response = await fetch('/api/chat/send', {
@@ -896,10 +931,10 @@ export default {
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
body: JSON.stringify({
- message: prompt,
+ message: promptToSend,
conversation_id: this.currCid,
- image_url: this.stagedImagesName,
- audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [],
+ image_url: imageNamesToSend,
+ audio_url: audioNameToSend ? [audioNameToSend] : [],
selected_provider: selectedProviderId,
selected_model: selectedModelName
})
@@ -1003,11 +1038,7 @@ export default {
}
}
- // Clear input after successful send
- this.prompt = '';
- this.stagedImagesName = [];
- this.stagedImagesUrl = [];
- this.stagedAudioUrl = "";
+ // Input and attachments are already cleared
this.loadingChat = false;
// get the latest conversations
From 28d78643939348042fab2b595253182fe05b4506 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sun, 20 Jul 2025 20:24:03 +0800
Subject: [PATCH 3/7] perf: tool use page UI (#2182)
* perf: tool use UI
* fix: update background color of item cards in ToolUsePage
---
.../i18n/locales/en-US/features/tool-use.json | 3 +-
.../i18n/locales/zh-CN/features/tool-use.json | 3 +-
dashboard/src/theme/DarkTheme.ts | 3 +-
dashboard/src/theme/LightTheme.ts | 5 +-
dashboard/src/types/themeTypes/ThemeType.ts | 1 +
dashboard/src/views/ToolUsePage.vue | 366 +++++++++---------
6 files changed, 184 insertions(+), 197 deletions(-)
diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json
index bd36fd68..96c4760e 100644
--- a/dashboard/src/i18n/locales/en-US/features/tool-use.json
+++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json
@@ -30,8 +30,7 @@
"functionTools": {
"title": "Function Tools",
"buttons": {
- "expand": "Expand",
- "collapse": "Collapse"
+ "view": "View Tools"
},
"search": "Search function tools",
"empty": "No function tools available",
diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
index c9e8e858..61b8691b 100644
--- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
+++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
@@ -30,8 +30,7 @@
"functionTools": {
"title": "函数工具",
"buttons": {
- "expand": "展开",
- "collapse": "收起"
+ "view": "查看工具"
},
"search": "搜索函数工具",
"empty": "没有可用的函数工具",
diff --git a/dashboard/src/theme/DarkTheme.ts b/dashboard/src/theme/DarkTheme.ts
index 9899fcff..177bee39 100644
--- a/dashboard/src/theme/DarkTheme.ts
+++ b/dashboard/src/theme/DarkTheme.ts
@@ -36,12 +36,13 @@ const PurpleThemeDark: ThemeTypes = {
gray100: '#cccccccc',
primary200: '#90caf9',
secondary200: '#b39ddb',
- background: '#111111',
+ background: '#1d1d1d',
overlay: '#111111aa',
codeBg: '#282833',
preBg: 'rgb(23, 23, 23)',
code: '#ffffffdd',
chatMessageBubble: '#2d2e30',
+ mcpCardBg: '#2a2a2a',
}
};
diff --git a/dashboard/src/theme/LightTheme.ts b/dashboard/src/theme/LightTheme.ts
index 03630523..b8fdec25 100644
--- a/dashboard/src/theme/LightTheme.ts
+++ b/dashboard/src/theme/LightTheme.ts
@@ -27,7 +27,7 @@ const PurpleTheme: ThemeTypes = {
borderLight: '#d0d0d0',
border: '#d0d0d0',
inputBorder: '#787878',
- containerBg: '#f7f1f6',
+ containerBg: '#f9fafcf4',
surface: '#fff',
'on-surface-variant': '#fff',
facebook: '#4267b2',
@@ -36,12 +36,13 @@ const PurpleTheme: ThemeTypes = {
gray100: '#fafafacc',
primary200: '#90caf9',
secondary200: '#b39ddb',
- background: '#f9fafcf4',
+ background: '#ffffff',
overlay: '#ffffffaa',
codeBg: '#ececec',
preBg: 'rgb(249, 249, 249)',
code: 'rgb(13, 13, 13)',
chatMessageBubble: '#e7ebf4',
+ mcpCardBg: '#f7f2f9',
}
};
diff --git a/dashboard/src/types/themeTypes/ThemeType.ts b/dashboard/src/types/themeTypes/ThemeType.ts
index b18ee3dc..8d276004 100644
--- a/dashboard/src/types/themeTypes/ThemeType.ts
+++ b/dashboard/src/types/themeTypes/ThemeType.ts
@@ -37,5 +37,6 @@ export type ThemeTypes = {
preBg?: string;
code?: string;
chatMessageBubble?: string;
+ mcpCardBg?: string;
};
};
diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue
index 93f53924..6060088a 100644
--- a/dashboard/src/views/ToolUsePage.vue
+++ b/dashboard/src/views/ToolUsePage.vue
@@ -20,10 +20,16 @@
-
- {{ tm('mcpServers.buttons.add') }}
-
+
+
+ {{ tm('functionTools.buttons.view') }}({{ tools.length }})
+
+
+ {{ tm('mcpServers.buttons.add') }}
+
+
@@ -45,200 +51,79 @@
-
-
- mdi-server
- {{ tm('mcpServers.title') }}
-
-
- {{ tm('mcpServers.buttons.refresh') }}
-
-
- {{ tm('mcpServers.buttons.add') }}
-
-
-
+
+
mdi-server-off
+
{{ tm('mcpServers.empty') }}
+
-
-
-
mdi-server-off
-
{{ tm('mcpServers.empty') }}
-
-
-
-
-
-
-
- mdi-file-code
-
- {{ getServerConfigSummary(item) }}
-
-
-
-
-
-
-
-
-
mdi-tools
-
-
-
- {{ tm('mcpServers.status.availableTools', { count: item.tools.length }) }} ({{
- item.tools.length }})
-
-
-
-
-
- {{ tm('mcpServers.status.availableTools') }}
-
-
-
-
-
-
- Close
-
-
-
-
-
-
-
-
-
-
- mdi-alert-circle
- {{ tm('mcpServers.status.noTools') }}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- mdi-function
- {{ tm('functionTools.title') }}
- {{ tools.length }}
-
-
- {{ showTools ? tm('functionTools.buttons.collapse') : tm('functionTools.buttons.expand') }}
- {{ showTools ? 'mdi-chevron-up' : 'mdi-chevron-down' }}
-
-
-
-
-
-
-
-
-
-
mdi-api-off
-
{{ tm('functionTools.empty') }}
+
+
+
+
+
+ mdi-file-code
+
+ {{ getServerConfigSummary(item) }}
+
-
-
-
-
-
-
-
-
-
- {{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
-
-
- {{ formatToolName(tool.function.name) }}
-
-
-
-
- {{ tool.function.description }}
-
-
-
+
+
+
+
+
mdi-tools
+
+
+
+ {{ tm('mcpServers.status.availableTools', { count: item.tools.length }) }} ({{
+ item.tools.length }})
+
+
+
+
+
+ {{ tm('mcpServers.status.availableTools') }}
+
+
+
+
+
+
+ Close
+
+
+
+
-
-
-
-
- mdi-information
- {{ tm('functionTools.description') }}
-
- {{ tool.function.description }}
-
-
- mdi-code-json
- {{ tm('functionTools.parameters') }}
-
-
-
-
-
- | {{ tm('functionTools.table.paramName') }} |
- {{ tm('functionTools.table.type') }} |
- {{ tm('functionTools.table.description') }} |
-
-
-
-
- | {{ paramName }} |
-
-
- {{ param.type }}
-
- |
- {{ param.description }} |
-
-
-
-
-
-
mdi-code-brackets
-
{{ tm('functionTools.noParameters') }}
-
-
-
-
-
-
+
+
+
+
+ mdi-alert-circle
+ {{ tm('mcpServers.status.noTools') }}
+
+
+
+
+
-
-
-
-
+
+
+
+
+
+
+
@@ -501,6 +386,106 @@
+
+
+
+
+ {{ tm('functionTools.title') }}
+ {{ tools.length }}
+
+
+
+
+
+
mdi-api-off
+
{{ tm('functionTools.empty') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
+
+
+ {{ formatToolName(tool.function.name) }}
+
+
+
+
+ {{ tool.function.description }}
+
+
+
+
+
+
+
+
+ mdi-information
+ {{ tm('functionTools.description') }}
+
+ {{ tool.function.description }}
+
+
+
+ mdi-code-json
+ {{ tm('functionTools.parameters') }}
+
+
+
+
+
+ | {{ tm('functionTools.table.paramName') }} |
+ {{ tm('functionTools.table.type') }} |
+ {{ tm('functionTools.table.description') }} |
+
+
+
+
+ | {{ paramName }} |
+
+
+ {{ param.type }}
+
+ |
+ {{ param.description }} |
+
+
+
+
+
+
mdi-code-brackets
+
{{ tm('functionTools.noParameters') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ tm('dialogs.serverDetail.buttons.close') }}
+
+
+
+
+
@@ -537,6 +522,7 @@ export default {
showMcpServerDialog: false,
showServerDetailDialog: false,
addServerDialogMessage: "",
+ showToolsDialog: false,
showTools: true,
loading: false,
loadingGettingServers: false,
From e92fbb04431a98708b673321db39314305ab02bb Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Mon, 21 Jul 2025 15:05:49 +0800
Subject: [PATCH 4/7] feat: add ProxySelector component for GitHub proxy
configuration and connection testing (#2185)
---
astrbot/dashboard/routes/stat.py | 45 +++++-
.../src/components/shared/ProxySelector.vue | 152 ++++++++++++++++++
dashboard/src/views/ExtensionPage.vue | 4 +
dashboard/src/views/Settings.vue | 34 +---
4 files changed, 200 insertions(+), 35 deletions(-)
create mode 100644 dashboard/src/components/shared/ProxySelector.vue
diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py
index 79397290..2a838939 100644
--- a/astrbot/dashboard/routes/stat.py
+++ b/astrbot/dashboard/routes/stat.py
@@ -2,6 +2,7 @@ import traceback
import psutil
import time
import threading
+import aiohttp
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
@@ -25,6 +26,7 @@ class StatRoute(Route):
"/stat/version": ("GET", self.get_version),
"/stat/start-time": ("GET", self.get_start_time),
"/stat/restart-core": ("POST", self.restart_core),
+ "/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
}
self.db_helper = db_helper
self.register_routes()
@@ -45,11 +47,7 @@ class StatRoute(Route):
"""将总秒数转换为时分秒组件"""
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
- return {
- "hours": hours,
- "minutes": minutes,
- "seconds": seconds
- }
+ return {"hours": hours, "minutes": minutes, "seconds": seconds}
def is_default_cred(self):
username = self.config["dashboard"]["username"]
@@ -144,3 +142,40 @@ class StatRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
+
+ async def test_ghproxy_connection(self):
+ """
+ 测试 GitHub 代理连接是否可用。
+ """
+ try:
+ data = await request.get_json()
+ proxy_url: str = data.get("proxy_url")
+
+ if not proxy_url:
+ return Response().error("proxy_url is required").__dict__
+
+ proxy_url = proxy_url.rstrip("/")
+
+ test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
+ start_time = time.time()
+
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ test_url, timeout=aiohttp.ClientTimeout(total=10)
+ ) as response:
+ if response.status == 200:
+ end_time = time.time()
+ _ = await response.text()
+ ret = {
+ "latency": round((end_time - start_time) * 1000, 2),
+ }
+ return Response().ok(data=ret).__dict__
+ else:
+ return (
+ Response()
+ .error(f"Failed. Status code: {response.status}")
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"Error: {str(e)}").__dict__
diff --git a/dashboard/src/components/shared/ProxySelector.vue b/dashboard/src/components/shared/ProxySelector.vue
new file mode 100644
index 00000000..d45a0f52
--- /dev/null
+++ b/dashboard/src/components/shared/ProxySelector.vue
@@ -0,0 +1,152 @@
+
+ GitHub 加速
+
+
+
+
+ 使用 GitHub 加速
+
+ 测试代理连通性
+
+
+
+
+
+
+
+
+
+
{{ proxy }}
+
+
+ {{ proxyStatus[idx].available ? '可用' : '不可用' }}
+
+
+ {{ proxyStatus[idx].latency }}ms
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue
index 892034ad..c8fb6245 100644
--- a/dashboard/src/views/ExtensionPage.vue
+++ b/dashboard/src/views/ExtensionPage.vue
@@ -3,6 +3,7 @@ import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
+import ProxySelector from '@/components/shared/ProxySelector.vue';
import axios from 'axios';
import { useCommonStore } from '@/stores/common';
import { useI18n, useModuleI18n } from '@/i18n/composables';
@@ -1190,6 +1191,9 @@ onMounted(async () => {
hide-details
placeholder="https://github.com/username/repo"
>
+
diff --git a/dashboard/src/views/Settings.vue b/dashboard/src/views/Settings.vue
index 0b68ab7d..77ad4ea1 100644
--- a/dashboard/src/views/Settings.vue
+++ b/dashboard/src/views/Settings.vue
@@ -5,11 +5,8 @@
{{ tm('network.title') }}
-
-
-
-
+
+
{{ tm('system.title') }}
@@ -17,41 +14,29 @@
{{ tm('system.restart.button') }}
-
-
-
-
-
\ No newline at end of file
From 3ccbef141ee2bfe6fa7257a38aedf53028b65bd5 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Mon, 21 Jul 2025 15:16:49 +0800
Subject: [PATCH 5/7] perf: extension ui
---
astrbot/core/star/star_manager.py | 2 +-
dashboard/src/views/ExtensionPage.vue | 161 +++++++++++---------------
2 files changed, 68 insertions(+), 95 deletions(-)
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 4a6d4d90..781a5141 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -163,7 +163,7 @@ class PluginManager:
plugins.extend(_p)
return plugins
- async def _check_plugin_dept_update(self, target_plugin: str = None):
+ async def _check_plugin_dept_update(self, target_plugin: str | None = None):
"""检查插件的依赖
如果 target_plugin 为 None,则检查所有插件的依赖
"""
diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue
index c8fb6245..a3dbc9c3 100644
--- a/dashboard/src/views/ExtensionPage.vue
+++ b/dashboard/src/views/ExtensionPage.vue
@@ -30,12 +30,12 @@ const extension_config = reactive({
config: {}
});
const pluginMarketData = ref([]);
- const loadingDialog = reactive({
- show: false,
- title: "",
- statusCode: 0, // 0: loading, 1: success, 2: error,
- result: ""
- });
+const loadingDialog = reactive({
+ show: false,
+ title: "",
+ statusCode: 0, // 0: loading, 1: success, 2: error,
+ result: ""
+});
const showPluginInfoDialog = ref(false);
const selectedPlugin = ref({});
const curr_namespace = ref("");
@@ -185,8 +185,8 @@ const checkUpdate = () => {
if (matchedPlugin) {
extension.online_version = matchedPlugin.version;
- extension.has_update = extension.version !== matchedPlugin.version &&
- matchedPlugin.version !== tm('status.unknown');
+ extension.has_update = extension.version !== matchedPlugin.version &&
+ matchedPlugin.version !== tm('status.unknown');
} else {
extension.has_update = false;
}
@@ -623,27 +623,12 @@ onMounted(async () => {
-
+
-
+
@@ -679,33 +664,32 @@ onMounted(async () => {
mdi-plus
{{ tm('buttons.install') }}
-
-
-
-
-
- mdi-alert-circle
-
-
-
-
-
-
- mdi-alert-circle
- {{ tm('dialogs.error.title') }}
-
-
- {{ extension_data.message }}
- {{ tm('dialogs.error.checkConsole') }}
-
-
-
- {{ tm('buttons.close') }}
-
-
-
-
+
+
+
+
+ mdi-alert-circle
+
+
+
+
+
+ mdi-alert-circle
+ {{ tm('dialogs.error.title') }}
+
+
+ {{ extension_data.message }}
+ {{ tm('dialogs.error.checkConsole') }}
+
+
+
+ {{ tm('buttons.close') }}
+
+
+
+
+
@@ -727,7 +711,8 @@ onMounted(async () => {
{{ item.name }}
- {{ tm('status.system') }}
+ {{ tm('status.system')
+ }}
@@ -848,8 +833,8 @@ onMounted(async () => {
-
+
@@ -866,8 +851,8 @@ onMounted(async () => {
{{ tm('market.allPlugins') }}
-
+
@@ -905,7 +890,8 @@ onMounted(async () => {
-
-
+
{{ tag }}
@@ -959,7 +945,8 @@ onMounted(async () => {
{{ tm('dialogs.platformConfig.noAdapters') }}
{{ tm('dialogs.platformConfig.noAdaptersDesc') }}
- {{ tm('dialogs.platformConfig.goPlatforms') }}
+ {{ tm('dialogs.platformConfig.goPlatforms')
+ }}
@@ -994,7 +981,7 @@ onMounted(async () => {
-
+
@@ -1003,7 +990,8 @@ onMounted(async () => {
{{ plugin.name }}
- {{ tm('status.system') }}
+ {{ tm('status.system')
+ }}
{{ plugin.desc }}
|
@@ -1019,8 +1007,8 @@ onMounted(async () => {
{{ tm('buttons.close') }}
- {{ tm('buttons.save') }}
+ {{
+ tm('buttons.save') }}
@@ -1059,12 +1047,13 @@ onMounted(async () => {
{{ tm('dialogs.loading.logs') }}
-
+
+
-
+
-
+
{{ tm('buttons.close') }}
@@ -1100,7 +1089,8 @@ onMounted(async () => {
- {{ tm('buttons.close') }}
+ {{ tm('buttons.close')
+ }}
@@ -1147,25 +1137,13 @@ onMounted(async () => {
-
-
-
+
+
+
{{ tm('buttons.selectFile') }}
-
+
{{ tm('messages.supportedFormats') }}
@@ -1183,14 +1161,9 @@ onMounted(async () => {
-
+
From d7fd6164703e5da3865f1a8c892385cdb9f37826 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Mon, 21 Jul 2025 17:04:29 +0800
Subject: [PATCH 6/7] style: code quality
---
astrbot/core/star/__init__.py | 10 +++++-
astrbot/core/star/star.py | 8 +++--
astrbot/core/star/star_manager.py | 46 +++++++++++++++---------
astrbot/core/utils/io.py | 2 +-
astrbot/core/utils/shared_preferences.py | 4 ++-
packages/astrbot/main.py | 4 +++
6 files changed, 53 insertions(+), 21 deletions(-)
diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py
index 25c29165..86318f8b 100644
--- a/astrbot/core/star/__init__.py
+++ b/astrbot/core/star/__init__.py
@@ -10,7 +10,7 @@ from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
- def __init__(self, context: Context):
+ def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context)
self.context = context
@@ -41,9 +41,17 @@ class Star(CommandParserMixin):
tmpl, data, return_url=return_url, options=options
)
+ async def initialize(self):
+ """当插件被激活时会调用这个方法"""
+ pass
+
async def terminate(self):
"""当插件被禁用、重载插件时会调用这个方法"""
pass
+ def __del__(self):
+ """[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+ pass
+
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py
index d4438823..2fe9dd7f 100644
--- a/astrbot/core/star/star.py
+++ b/astrbot/core/star/star.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from types import ModuleType
+from typing import TYPE_CHECKING
from astrbot.core.config import AstrBotConfig
@@ -9,6 +10,9 @@ star_registry: list[StarMetadata] = []
star_map: dict[str, StarMetadata] = {}
"""key 是模块路径,__module__"""
+if TYPE_CHECKING:
+ from . import Star
+
@dataclass
class StarMetadata:
@@ -29,12 +33,12 @@ class StarMetadata:
repo: str | None = None
"""插件仓库地址"""
- star_cls_type: type | None = None
+ star_cls_type: type[Star] | None = None
"""插件的类对象的类型"""
module_path: str | None = None
"""插件的模块路径"""
- star_cls: object | None = None
+ star_cls: Star | None = None
"""插件的类对象"""
module: ModuleType | None = None
"""插件的模块对象"""
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 781a5141..b64b4aa8 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -187,7 +187,7 @@ class PluginManager:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
@staticmethod
- def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata:
+ def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
@@ -253,8 +253,8 @@ class PluginManager:
def _purge_modules(
self,
- module_patterns: list[str] = None,
- root_dir_name: str = None,
+ module_patterns: list[str] | None = None,
+ root_dir_name: str | None = None,
is_reserved: bool = False,
):
"""从 sys.modules 中移除指定的模块
@@ -314,8 +314,8 @@ class PluginManager:
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
-
- await self._unbind_plugin(smd.name, smd.module_path)
+ if smd.name and smd.module_path:
+ await self._unbind_plugin(smd.name, smd.module_path)
star_handlers_registry.clear()
star_map.clear()
@@ -331,8 +331,8 @@ class PluginManager:
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
-
- await self._unbind_plugin(smd.name, specified_module_path)
+ if smd.name:
+ await self._unbind_plugin(smd.name, specified_module_path)
result = await self.load(specified_module_path)
@@ -460,8 +460,7 @@ class PluginManager:
metadata.config = plugin_config
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
- if plugin_config:
- # metadata.config = plugin_config
+ if plugin_config and metadata.star_cls_type:
try:
metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config
@@ -470,7 +469,7 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
- else:
+ elif metadata.star_cls_type:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
@@ -487,6 +486,10 @@ class PluginManager:
)
metadata.update_platform_compatibility(plugin_enable_config)
+ assert metadata.module_path is not None, (
+ f"插件 {metadata.name} 的模块路径为空。"
+ )
+
# 绑定 handler
related_handlers = (
star_handlers_registry.get_handlers_by_module_name(
@@ -495,7 +498,8 @@ class PluginManager:
)
for handler in related_handlers:
handler.handler = functools.partial(
- handler.handler, metadata.star_cls
+ handler.handler,
+ metadata.star_cls, # type: ignore
)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
@@ -505,7 +509,8 @@ class PluginManager:
):
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(
- func_tool.handler, metadata.star_cls
+ func_tool.handler,
+ metadata.star_cls, # type: ignore
)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
@@ -532,13 +537,12 @@ class PluginManager:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
- else:
- logger.info(f"插件 {metadata.name} 已被禁用。")
- metadata = None
metadata = self._load_plugin_metadata(
plugin_path=plugin_dir_path, plugin_obj=obj
)
+ if not metadata:
+ raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
@@ -553,6 +557,10 @@ class PluginManager:
if metadata.module_path in inactivated_plugins:
metadata.activated = False
+ assert metadata.module_path is not None, (
+ f"插件 {metadata.name} 的模块路径为空。"
+ )
+
full_names = []
for handler in star_handlers_registry.get_handlers_by_module_name(
metadata.module_path
@@ -592,7 +600,7 @@ class PluginManager:
metadata.star_handler_full_names = full_names
# 执行 initialize() 方法
- if hasattr(metadata.star_cls, "initialize"):
+ if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
await metadata.star_cls.initialize()
except BaseException as e:
@@ -734,6 +742,9 @@ class PluginManager:
]:
del star_handlers_registry.star_handlers_map[k]
+ if plugin is None:
+ return
+
self._purge_modules(
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
)
@@ -795,6 +806,9 @@ class PluginManager:
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return
+ if star_metadata.star_cls is None:
+ return
+
if hasattr(star_metadata.star_cls, "__del__"):
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py
index 2cd8fd9c..2b34c2a1 100644
--- a/astrbot/core/utils/io.py
+++ b/astrbot/core/utils/io.py
@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
raise exc_info[1]
-def remove_dir(file_path) -> bool:
+def remove_dir(file_path: str) -> bool:
if not os.path.exists(file_path):
return True
shutil.rmtree(file_path, onerror=on_error)
diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py
index 7a503583..42018d19 100644
--- a/astrbot/core/utils/shared_preferences.py
+++ b/astrbot/core/utils/shared_preferences.py
@@ -1,7 +1,9 @@
import json
import os
+from typing import TypeVar
from .astrbot_path import get_astrbot_data_path
+_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
@@ -24,7 +26,7 @@ class SharedPreferences:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
- def get(self, key, default=None):
+ def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py
index fcb34e25..404d65f8 100644
--- a/packages/astrbot/main.py
+++ b/packages/astrbot/main.py
@@ -1242,6 +1242,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
logger.error(traceback.format_exc())
logger.error(f"主动回复失败: {e}")
+ @filter.on_decorating_result()
+ async def decorate_result(self, event: AstrMessageEvent):
+ logger.debug("Decorating result for event: %s", event)
+
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
From 140ddc70e683bfc9ee89e4fd94c1d11244b93c35 Mon Sep 17 00:00:00 2001
From: Gao Jinzhe <2968474907@qq.com>
Date: Wed, 23 Jul 2025 00:37:29 +0800
Subject: [PATCH 7/7] =?UTF-8?q?feat:=20=E4=BD=BF=E7=94=A8=E4=BC=9A?=
=?UTF-8?q?=E8=AF=9D=E9=94=81=E4=BF=9D=E8=AF=81=E5=88=86=E6=AE=B5=E5=9B=9E?=
=?UTF-8?q?=E5=A4=8D=E6=97=B6=E7=9A=84=E6=B6=88=E6=81=AF=E5=8F=91=E9=80=81?=
=?UTF-8?q?=E9=A1=BA=E5=BA=8F=20(#2130)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* 优化分段消息发送逻辑,为分段消息添加消息队列
* 删除了不必要的代码
* style: code quality
* 将消息队列机制重构为会话锁机制
* perf: narrow the lock scope
* refactor: replace get_lock with async context manager for session locks
* refactor: optimize session lock management with defaultdict
---------
Co-authored-by: Soulter <905617992@qq.com>
Co-authored-by: Raven95676
---
astrbot/core/pipeline/respond/stage.py | 40 ++++++++++++++------------
astrbot/core/utils/session_lock.py | 29 +++++++++++++++++++
2 files changed, 50 insertions(+), 19 deletions(-)
create mode 100644 astrbot/core/utils/session_lock.py
diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py
index 50b43604..54ad1e63 100644
--- a/astrbot/core/pipeline/respond/stage.py
+++ b/astrbot/core/pipeline/respond/stage.py
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
+from astrbot.core.utils.session_lock import session_lock_manager
@register_stage
@@ -177,25 +178,26 @@ class RespondStage(Stage):
result.chain.remove(comp)
break
- for rcomp in record_comps:
- i = await self._calc_comp_interval(rcomp)
- await asyncio.sleep(i)
- try:
- await event.send(MessageChain([rcomp]))
- except Exception as e:
- logger.error(f"发送消息失败: {e} chain: {result.chain}")
- break
-
- # 分段回复
- for comp in non_record_comps:
- i = await self._calc_comp_interval(comp)
- await asyncio.sleep(i)
- try:
- await event.send(MessageChain([*decorated_comps, comp]))
- decorated_comps = [] # 清空已发送的装饰组件
- except Exception as e:
- logger.error(f"发送消息失败: {e} chain: {result.chain}")
- break
+ # leverage lock to guarentee the order of message sending among different events
+ async with session_lock_manager.acquire_lock(event.unified_msg_origin):
+ for rcomp in record_comps:
+ i = await self._calc_comp_interval(rcomp)
+ await asyncio.sleep(i)
+ try:
+ await event.send(MessageChain([rcomp]))
+ except Exception as e:
+ logger.error(f"发送消息失败: {e} chain: {result.chain}")
+ break
+ # 分段回复
+ for comp in non_record_comps:
+ i = await self._calc_comp_interval(comp)
+ await asyncio.sleep(i)
+ try:
+ await event.send(MessageChain([*decorated_comps, comp]))
+ decorated_comps = [] # 清空已发送的装饰组件
+ except Exception as e:
+ logger.error(f"发送消息失败: {e} chain: {result.chain}")
+ break
else:
for rcomp in record_comps:
try:
diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py
new file mode 100644
index 00000000..912d91e5
--- /dev/null
+++ b/astrbot/core/utils/session_lock.py
@@ -0,0 +1,29 @@
+import asyncio
+from collections import defaultdict
+from contextlib import asynccontextmanager
+
+
+class SessionLockManager:
+ def __init__(self):
+ self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
+ self._lock_count: dict[str, int] = defaultdict(int)
+ self._access_lock = asyncio.Lock()
+
+ @asynccontextmanager
+ async def acquire_lock(self, session_id: str):
+ async with self._access_lock:
+ lock = self._locks[session_id]
+ self._lock_count[session_id] += 1
+
+ try:
+ async with lock:
+ yield
+ finally:
+ async with self._access_lock:
+ self._lock_count[session_id] -= 1
+ if self._lock_count[session_id] == 0:
+ self._locks.pop(session_id, None)
+ self._lock_count.pop(session_id, None)
+
+
+session_lock_manager = SessionLockManager()