Merge branch 'dev'
This commit is contained in:
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -177,25 +178,26 @@ class RespondStage(Stage):
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
|
||||
@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
@@ -80,12 +146,10 @@ class FuncTool:
|
||||
if not self.mcp_client or not self.mcp_client.session:
|
||||
raise Exception(f"MCP client for {self.name} is not available")
|
||||
# 使用name属性而不是额外的mcp_tool_name
|
||||
if ":" in self.name:
|
||||
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||
actual_tool_name = self.name.split(":")[-1]
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
return await self.mcp_client.session.call_tool(self.name, args)
|
||||
actual_tool_name = (
|
||||
self.name.split(":")[-1] if ":" in self.name else self.name
|
||||
)
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
|
||||
@@ -100,6 +164,7 @@ class MCPClient:
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
@@ -112,17 +177,19 @@ class MCPClient:
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = mcp_server_config.copy()
|
||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||
cfg = cfg["mcpServers"][key_0]
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
@@ -130,11 +197,18 @@ class MCPClient:
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(self._streams_context)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
@@ -148,11 +222,19 @@ class MCPClient:
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(self._streams_context)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -172,7 +254,7 @@ class MCPClient:
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
@@ -180,19 +262,18 @@ class MCPClient:
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
|
||||
|
||||
class FuncCall:
|
||||
@@ -201,8 +282,6 @@ class FuncCall:
|
||||
"""内部加载的 func tools"""
|
||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_service_queue = asyncio.Queue()
|
||||
"""用于外部控制 MCP 服务的启停"""
|
||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
@@ -258,7 +337,7 @@ class FuncCall:
|
||||
return f
|
||||
return None
|
||||
|
||||
async def _init_mcp_clients(self) -> None:
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
{
|
||||
@@ -300,115 +379,64 @@ class FuncCall:
|
||||
)
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
async def mcp_service_selector(self):
|
||||
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||
|
||||
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||
|
||||
{"type": "init"} 初始化所有MCP客户端
|
||||
|
||||
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||
|
||||
{"type": "terminate"} 终止所有MCP客户端
|
||||
|
||||
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||
"""
|
||||
while True:
|
||||
data = await self.mcp_service_queue.get()
|
||||
if data["type"] == "init":
|
||||
if "name" in data:
|
||||
event = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(
|
||||
data["name"], data["cfg"], event
|
||||
)
|
||||
)
|
||||
self.mcp_client_event[data["name"]] = event
|
||||
else:
|
||||
await self._init_mcp_clients()
|
||||
elif data["type"] == "terminate":
|
||||
if "name" in data:
|
||||
# await self._terminate_mcp_client(data["name"])
|
||||
if data["name"] in self.mcp_client_event:
|
||||
self.mcp_client_event[data["name"]].set()
|
||||
self.mcp_client_event.pop(data["name"], None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (
|
||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
for name in self.mcp_client_dict.keys():
|
||||
# await self._terminate_mcp_client(name)
|
||||
# self.mcp_client_event[name].set()
|
||||
if name in self.mcp_client_event:
|
||||
self.mcp_client_event[name].set()
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self, name: str, cfg: dict, event: asyncio.Event
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
finally:
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
"""初始化单个MCP客户端"""
|
||||
try:
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||
]
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||
]
|
||||
|
||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||
for tool in mcp_client.tools:
|
||||
func_tool = FuncTool(
|
||||
name=tool.name,
|
||||
parameters=tool.inputSchema,
|
||||
description=tool.description,
|
||||
origin="mcp",
|
||||
mcp_server_name=name,
|
||||
mcp_client=mcp_client,
|
||||
)
|
||||
self.func_list.append(func_tool)
|
||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||
for tool in mcp_client.tools:
|
||||
func_tool = FuncTool(
|
||||
name=tool.name,
|
||||
parameters=tool.inputSchema,
|
||||
description=tool.description,
|
||||
origin="mcp",
|
||||
mcp_server_name=name,
|
||||
mcp_client=mcp_client,
|
||||
)
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -418,7 +446,7 @@ class FuncCall:
|
||||
await self.mcp_client_dict[name].cleanup()
|
||||
self.mcp_client_dict.pop(name)
|
||||
except Exception as e:
|
||||
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
@@ -427,6 +455,103 @@ class FuncCall:
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
if "url" in config:
|
||||
success, error_msg = await _quick_test_mcp_connection(config)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
try:
|
||||
logger.debug(f"testing MCP server connection with config: {config}")
|
||||
await mcp_client.connect_to_server(config, "test")
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
finally:
|
||||
logger.debug("Cleaning up MCP client after testing connection.")
|
||||
await mcp_client.cleanup()
|
||||
return tool_names
|
||||
|
||||
async def enable_mcp_server(
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
Raises:
|
||||
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str | None = None, timeout: float = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server by its name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout (int): Timeout.
|
||||
"""
|
||||
if name:
|
||||
if name not in self.mcp_client_event:
|
||||
return
|
||||
client = self.mcp_client_dict.get(name)
|
||||
self.mcp_client_event[name].set()
|
||||
if not client:
|
||||
return
|
||||
client_running_event = client.running_event
|
||||
try:
|
||||
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if f.origin != "mcp" or f.mcp_server_name != name
|
||||
]
|
||||
else:
|
||||
running_events = [
|
||||
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||
]
|
||||
for key, event in self.mcp_client_event.items():
|
||||
event.set()
|
||||
# waiting for all clients to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.clear()
|
||||
self.mcp_client_dict.clear()
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
|
||||
@@ -169,10 +169,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
||||
)
|
||||
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
@@ -422,7 +419,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = None
|
||||
|
||||
if getattr(self.inst_map[provider_id], "terminate", None):
|
||||
await self.inst_map[provider_id].terminate() # type: ignore
|
||||
await self.inst_map[provider_id].terminate() # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
||||
@@ -432,6 +429,8 @@ class ProviderManager:
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
await provider_inst.terminate() # type: ignore
|
||||
# 清理 MCP Client 连接
|
||||
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
||||
await provider_inst.terminate() # type: ignore
|
||||
try:
|
||||
await self.llm_tools.disable_mcp_server()
|
||||
except Exception:
|
||||
logger.error("Error while disabling MCP servers", exc_info=True)
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core.star.star_tools import StarTools
|
||||
class Star(CommandParserMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
def __init__(self, context: Context, config: dict | None = None):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
@@ -41,9 +41,17 @@ class Star(CommandParserMixin):
|
||||
tmpl, data, return_url=return_url, options=options
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""当插件被激活时会调用这个方法"""
|
||||
pass
|
||||
|
||||
async def terminate(self):
|
||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
@@ -9,6 +10,9 @@ star_registry: list[StarMetadata] = []
|
||||
star_map: dict[str, StarMetadata] = {}
|
||||
"""key 是模块路径,__module__"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Star
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
@@ -29,12 +33,12 @@ class StarMetadata:
|
||||
repo: str | None = None
|
||||
"""插件仓库地址"""
|
||||
|
||||
star_cls_type: type | None = None
|
||||
star_cls_type: type[Star] | None = None
|
||||
"""插件的类对象的类型"""
|
||||
module_path: str | None = None
|
||||
"""插件的模块路径"""
|
||||
|
||||
star_cls: object | None = None
|
||||
star_cls: Star | None = None
|
||||
"""插件的类对象"""
|
||||
module: ModuleType | None = None
|
||||
"""插件的模块对象"""
|
||||
|
||||
@@ -163,7 +163,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str | None = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -187,7 +187,7 @@ class PluginManager:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata:
|
||||
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
|
||||
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
|
||||
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
|
||||
@@ -253,8 +253,8 @@ class PluginManager:
|
||||
|
||||
def _purge_modules(
|
||||
self,
|
||||
module_patterns: list[str] = None,
|
||||
root_dir_name: str = None,
|
||||
module_patterns: list[str] | None = None,
|
||||
root_dir_name: str | None = None,
|
||||
is_reserved: bool = False,
|
||||
):
|
||||
"""从 sys.modules 中移除指定的模块
|
||||
@@ -314,8 +314,8 @@ class PluginManager:
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
if smd.name and smd.module_path:
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
@@ -331,8 +331,8 @@ class PluginManager:
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
if smd.name:
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
@@ -460,8 +460,7 @@ class PluginManager:
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
# metadata.config = plugin_config
|
||||
if plugin_config and metadata.star_cls_type:
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -470,7 +469,7 @@ class PluginManager:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
else:
|
||||
elif metadata.star_cls_type:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
@@ -487,6 +486,10 @@ class PluginManager:
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = (
|
||||
star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -495,7 +498,8 @@ class PluginManager:
|
||||
)
|
||||
for handler in related_handlers:
|
||||
handler.handler = functools.partial(
|
||||
handler.handler, metadata.star_cls
|
||||
handler.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
@@ -505,7 +509,8 @@ class PluginManager:
|
||||
):
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(
|
||||
func_tool.handler, metadata.star_cls
|
||||
func_tool.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
@@ -532,13 +537,12 @@ class PluginManager:
|
||||
obj = getattr(module, classes[0])(
|
||||
context=self.context
|
||||
) # 实例化插件类
|
||||
else:
|
||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||
|
||||
metadata = None
|
||||
metadata = self._load_plugin_metadata(
|
||||
plugin_path=plugin_dir_path, plugin_obj=obj
|
||||
)
|
||||
if not metadata:
|
||||
raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
@@ -553,6 +557,10 @@ class PluginManager:
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
full_names = []
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||
metadata.module_path
|
||||
@@ -592,7 +600,7 @@ class PluginManager:
|
||||
metadata.star_handler_full_names = full_names
|
||||
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
@@ -734,6 +742,9 @@ class PluginManager:
|
||||
]:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
if plugin is None:
|
||||
return
|
||||
|
||||
self._purge_modules(
|
||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||
)
|
||||
@@ -795,6 +806,9 @@ class PluginManager:
|
||||
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
||||
return
|
||||
|
||||
if star_metadata.star_cls is None:
|
||||
return
|
||||
|
||||
if hasattr(star_metadata.star_cls, "__del__"):
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
|
||||
@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
|
||||
raise exc_info[1]
|
||||
|
||||
|
||||
def remove_dir(file_path) -> bool:
|
||||
def remove_dir(file_path: str) -> bool:
|
||||
if not os.path.exists(file_path):
|
||||
return True
|
||||
shutil.rmtree(file_path, onerror=on_error)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
class SessionLockManager:
|
||||
def __init__(self):
|
||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._lock_count: dict[str, int] = defaultdict(int)
|
||||
self._access_lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire_lock(self, session_id: str):
|
||||
async with self._access_lock:
|
||||
lock = self._locks[session_id]
|
||||
self._lock_count[session_id] += 1
|
||||
|
||||
try:
|
||||
async with lock:
|
||||
yield
|
||||
finally:
|
||||
async with self._access_lock:
|
||||
self._lock_count[session_id] -= 1
|
||||
if self._lock_count[session_id] == 0:
|
||||
self._locks.pop(session_id, None)
|
||||
self._lock_count.pop(session_id, None)
|
||||
|
||||
|
||||
session_lock_manager = SessionLockManager()
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
@@ -24,7 +26,7 @@ class SharedPreferences:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
|
||||
@@ -2,6 +2,7 @@ import traceback
|
||||
import psutil
|
||||
import time
|
||||
import threading
|
||||
import aiohttp
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
@@ -25,6 +26,7 @@ class StatRoute(Route):
|
||||
"/stat/version": ("GET", self.get_version),
|
||||
"/stat/start-time": ("GET", self.get_start_time),
|
||||
"/stat/restart-core": ("POST", self.restart_core),
|
||||
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.register_routes()
|
||||
@@ -45,11 +47,7 @@ class StatRoute(Route):
|
||||
"""将总秒数转换为时分秒组件"""
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return {
|
||||
"hours": hours,
|
||||
"minutes": minutes,
|
||||
"seconds": seconds
|
||||
}
|
||||
return {"hours": hours, "minutes": minutes, "seconds": seconds}
|
||||
|
||||
def is_default_cred(self):
|
||||
username = self.config["dashboard"]["username"]
|
||||
@@ -144,3 +142,40 @@ class StatRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def test_ghproxy_connection(self):
|
||||
"""
|
||||
测试 GitHub 代理连接是否可用。
|
||||
"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
proxy_url: str = data.get("proxy_url")
|
||||
|
||||
if not proxy_url:
|
||||
return Response().error("proxy_url is required").__dict__
|
||||
|
||||
proxy_url = proxy_url.rstrip("/")
|
||||
|
||||
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
|
||||
start_time = time.time()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
test_url, timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
end_time = time.time()
|
||||
_ = await response.text()
|
||||
ret = {
|
||||
"latency": round((end_time - start_time) * 1000, 2),
|
||||
}
|
||||
return Response().ok(data=ret).__dict__
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed. Status code: {response.status}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {str(e)}").__dict__
|
||||
|
||||
@@ -26,6 +26,7 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/update": ("POST", self.update_mcp_server),
|
||||
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
|
||||
"/tools/mcp/market": ("GET", self.get_mcp_markets),
|
||||
"/tools/mcp/test": ("POST", self.test_mcp_connection),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
@@ -132,12 +133,19 @@ class ToolsRoute(Route):
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 动态初始化新MCP客户端
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, server_config, timeout=30
|
||||
)
|
||||
except TimeoutError:
|
||||
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
return Response().error("保存配置失败").__dict__
|
||||
@@ -193,31 +201,55 @@ class ToolsRoute(Route):
|
||||
if self.save_mcp_config(config):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
# 如果要激活服务器或者配置已更改
|
||||
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
else:
|
||||
# 客户端不存在,初始化
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError as e:
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, config["mcpServers"][name], timeout=30
|
||||
)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
# 如果要停用服务器
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
self.tool_mgr.mcp_service_queue.put_nowait({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 超时。")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
@@ -239,17 +271,23 @@ class ToolsRoute(Route):
|
||||
if name not in config["mcpServers"]:
|
||||
return Response().error(f"服务器 {name} 不存在").__dict__
|
||||
|
||||
# 删除服务器配置
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 关闭并删除MCP客户端
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
self.tool_mgr.mcp_service_queue.put_nowait({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
return Response().error("保存配置失败").__dict__
|
||||
@@ -281,3 +319,20 @@ class ToolsRoute(Route):
|
||||
except Exception as _:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error("获取市场数据失败").__dict__
|
||||
|
||||
async def test_mcp_connection(self):
|
||||
"""
|
||||
测试 MCP 服务器连接
|
||||
"""
|
||||
try:
|
||||
server_data = await request.json
|
||||
config = server_data.get("mcp_server_config", None)
|
||||
|
||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||
return (
|
||||
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
|
||||
|
||||
Reference in New Issue
Block a user