Compare commits
5 Commits
refactor/s
...
feat/astrb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdbed75ce4 | ||
|
|
9fec29c1a3 | ||
|
|
972b5ffb86 | ||
|
|
33e67bf925 | ||
|
|
185501d1b5 |
@@ -151,11 +151,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
# 如果有工具调用,还需处理工具调用
|
# 如果有工具调用,还需处理工具调用
|
||||||
if llm_resp.tools_call_name:
|
if llm_resp.tools_call_name:
|
||||||
tool_call_result_blocks = []
|
tool_call_result_blocks = []
|
||||||
for tool_call_name in llm_resp.tools_call_name:
|
for tool_call_name, tool_call_id in zip(
|
||||||
|
llm_resp.tools_call_name, llm_resp.tools_call_ids
|
||||||
|
):
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type="tool_call",
|
type="tool_call",
|
||||||
data=AgentResponseData(
|
data=AgentResponseData(
|
||||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
|
chain=MessageChain().message(f"🔨 正在使用工具: {tool_call_name} ({tool_call_id})")
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||||
@@ -255,63 +257,48 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
async for resp in executor: # type: ignore
|
async for resp in executor: # type: ignore
|
||||||
if isinstance(resp, CallToolResult):
|
if isinstance(resp, CallToolResult):
|
||||||
res = resp
|
res = resp
|
||||||
_final_resp = resp
|
content = res.content
|
||||||
if isinstance(res.content[0], TextContent):
|
|
||||||
tool_call_result_blocks.append(
|
aggr_text_content = ""
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
for cont in content:
|
||||||
tool_call_id=func_tool_id,
|
if isinstance(cont, TextContent):
|
||||||
content=res.content[0].text,
|
aggr_text_content += cont.text
|
||||||
)
|
yield MessageChain().message(cont.text)
|
||||||
)
|
elif isinstance(cont, ImageContent):
|
||||||
yield MessageChain().message(res.content[0].text)
|
aggr_text_content += "\n返回了图片(已直接发送给用户)\n"
|
||||||
elif isinstance(res.content[0], ImageContent):
|
|
||||||
tool_call_result_blocks.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content="返回了图片(已直接发送给用户)",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
yield MessageChain(type="tool_direct_result").base64_image(
|
|
||||||
res.content[0].data
|
|
||||||
)
|
|
||||||
elif isinstance(res.content[0], EmbeddedResource):
|
|
||||||
resource = res.content[0].resource
|
|
||||||
if isinstance(resource, TextResourceContents):
|
|
||||||
tool_call_result_blocks.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content=resource.text,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
yield MessageChain().message(resource.text)
|
|
||||||
elif (
|
|
||||||
isinstance(resource, BlobResourceContents)
|
|
||||||
and resource.mimeType
|
|
||||||
and resource.mimeType.startswith("image/")
|
|
||||||
):
|
|
||||||
tool_call_result_blocks.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content="返回了图片(已直接发送给用户)",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
type="tool_direct_result"
|
type="tool_direct_result"
|
||||||
).base64_image(resource.blob)
|
).base64_image(cont.data)
|
||||||
else:
|
elif isinstance(cont, EmbeddedResource):
|
||||||
tool_call_result_blocks.append(
|
resource = cont.resource
|
||||||
ToolCallMessageSegment(
|
if isinstance(resource, TextResourceContents):
|
||||||
role="tool",
|
aggr_text_content += resource.text
|
||||||
tool_call_id=func_tool_id,
|
yield MessageChain().message(resource.text)
|
||||||
content="返回的数据类型不受支持",
|
elif (
|
||||||
|
isinstance(resource, BlobResourceContents)
|
||||||
|
and resource.mimeType
|
||||||
|
and resource.mimeType.startswith("image/")
|
||||||
|
):
|
||||||
|
aggr_text_content += (
|
||||||
|
"\n返回了图片(已直接发送给用户)\n"
|
||||||
|
)
|
||||||
|
yield MessageChain(
|
||||||
|
type="tool_direct_result"
|
||||||
|
).base64_image(resource.blob)
|
||||||
|
else:
|
||||||
|
aggr_text_content += "\n返回的数据类型不受支持。\n"
|
||||||
|
yield MessageChain().message(
|
||||||
|
"返回的数据类型不受支持。"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
yield MessageChain().message("返回的数据类型不受支持。")
|
|
||||||
|
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=aggr_text_content,
|
||||||
|
)
|
||||||
|
)
|
||||||
elif resp is None:
|
elif resp is None:
|
||||||
# Tool 直接请求发送消息给用户
|
# Tool 直接请求发送消息给用户
|
||||||
# 这里我们将直接结束 Agent Loop。
|
# 这里我们将直接结束 Agent Loop。
|
||||||
|
|||||||
@@ -813,7 +813,8 @@ class File(BaseMessageComponent):
|
|||||||
"""下载文件"""
|
"""下载文件"""
|
||||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
os.makedirs(download_dir, exist_ok=True)
|
os.makedirs(download_dir, exist_ok=True)
|
||||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
fname = self.name if self.name else uuid.uuid4().hex
|
||||||
|
file_path = os.path.join(download_dir, fname)
|
||||||
await download_file(self.url, file_path)
|
await download_file(self.url, file_path)
|
||||||
self.file_ = os.path.abspath(file_path)
|
self.file_ = os.path.abspath(file_path)
|
||||||
|
|
||||||
|
|||||||
@@ -232,7 +232,9 @@ class AiocqhttpAdapter(Platform):
|
|||||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||||
# Lagrange
|
# Lagrange
|
||||||
logger.info("guessing lagrange")
|
logger.info("guessing lagrange")
|
||||||
file_name = m["data"].get("file_name", "file")
|
file_name = m["data"].get(
|
||||||
|
"file_name", m["data"].get("file", "file")
|
||||||
|
)
|
||||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from astrbot.core.provider.provider import (
|
|||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.manager import ProviderManager
|
from astrbot.core.provider.manager import ProviderManager
|
||||||
@@ -258,6 +258,11 @@ class Context:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def add_llm_tool(self, *tools: FunctionTool) -> None:
|
||||||
|
"""添加一个 LLM 工具。"""
|
||||||
|
for tool in tools:
|
||||||
|
self.provider_manager.llm_tools.func_list.append(tool)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||||
"""
|
"""
|
||||||
|
|||||||
10
packages/astrbot_agent/_conf_schema.json
Normal file
10
packages/astrbot_agent/_conf_schema.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"endpoint": {
|
||||||
|
"description": "The endpoint URL of the Shipyard server.",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"access_token": {
|
||||||
|
"description": "The access token for authenticating with the Shipyard server.",
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
129
packages/astrbot_agent/commands/file.py
Normal file
129
packages/astrbot_agent/commands/file.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import os
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.all import Context
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.core.utils.session_waiter import (
|
||||||
|
session_waiter,
|
||||||
|
SessionController,
|
||||||
|
)
|
||||||
|
from ..sandbox_client import SandboxClient
|
||||||
|
|
||||||
|
|
||||||
|
class FileCommand:
|
||||||
|
def __init__(self, context: Context) -> None:
|
||||||
|
self.context = context
|
||||||
|
self.user_file_uploads: dict[str, list[str]] = {} # umo -> file_path
|
||||||
|
self.user_file_uploaded_files: dict[str, list[str]] = {} # umo -> file_path
|
||||||
|
"""记录用户上传过的文件,保存了文件在沙箱中的路径。这个在用户下一次 LLM 请求时会被用到,然后清空。"""
|
||||||
|
|
||||||
|
async def _upload_file_to_sandbox(self, event: AstrMessageEvent) -> list[str]:
|
||||||
|
"""将用户上传的文件上传到沙箱"""
|
||||||
|
sender_id = event.get_sender_id()
|
||||||
|
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||||
|
fpath_ls = self.user_file_uploads[sender_id]
|
||||||
|
errors = []
|
||||||
|
for path in fpath_ls:
|
||||||
|
try:
|
||||||
|
fname = os.path.basename(path)
|
||||||
|
data = await sb.upload_file(path, fname)
|
||||||
|
success = data.get("success", False)
|
||||||
|
if not success:
|
||||||
|
raise Exception(f"Upload failed: {data}")
|
||||||
|
file_path = data.get("file_path", "")
|
||||||
|
logger.info(f"File {fname} uploaded to sandbox at {file_path}")
|
||||||
|
self.user_file_uploaded_files.setdefault(sender_id, []).append(
|
||||||
|
file_path
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append((path, str(e)))
|
||||||
|
logger.error(f"Error uploading file {path}: {e}")
|
||||||
|
|
||||||
|
# clean up files
|
||||||
|
for path in fpath_ls:
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error removing temp file {path}: {e}")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
async def file(self, event: AstrMessageEvent):
|
||||||
|
"""等待用户上传文件或图片"""
|
||||||
|
await event.send(
|
||||||
|
MessageChain().message(
|
||||||
|
f"请上传一个或多个文件(或图片),使用 /endupload 结束上传。(请求者 ID: {event.get_sender_id()})"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
|
||||||
|
@session_waiter(timeout=600, record_history_chains=False) # type: ignore
|
||||||
|
async def empty_mention_waiter(
|
||||||
|
controller: SessionController, event: AstrMessageEvent
|
||||||
|
):
|
||||||
|
idiom = event.message_str
|
||||||
|
sender_id = event.get_sender_id()
|
||||||
|
|
||||||
|
if idiom == "endupload":
|
||||||
|
files = self.user_file_uploads.get(sender_id, [])
|
||||||
|
if not files:
|
||||||
|
await event.send(
|
||||||
|
event.plain_result("你没有上传任何文件,上传已取消。")
|
||||||
|
)
|
||||||
|
controller.stop()
|
||||||
|
return
|
||||||
|
await event.send(
|
||||||
|
event.plain_result(f"开始上传 {len(files)} 个文件到沙箱...")
|
||||||
|
)
|
||||||
|
errors = await self._upload_file_to_sandbox(event)
|
||||||
|
if errors:
|
||||||
|
error_msgs = "\n".join(
|
||||||
|
[f"{path}: {err}" for path, err in errors]
|
||||||
|
)
|
||||||
|
await event.send(
|
||||||
|
event.plain_result(
|
||||||
|
f"上传中出现错误:\n{error_msgs}\n其他文件已成功上传。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await event.send(
|
||||||
|
event.plain_result(
|
||||||
|
f"上传完毕,共上传 {len(files)} 个文件。文件信息已被保存,下一次 LLM 请求时会自动将信息附上。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.user_file_uploads.pop(sender_id, None)
|
||||||
|
controller.stop()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解析文件或图片消息
|
||||||
|
for comp in event.message_obj.message:
|
||||||
|
if isinstance(comp, (Comp.File, Comp.Image)):
|
||||||
|
if isinstance(comp, Comp.File):
|
||||||
|
path = await comp.get_file()
|
||||||
|
self.user_file_uploads.setdefault(
|
||||||
|
event.get_sender_id(), []
|
||||||
|
).append(path)
|
||||||
|
elif isinstance(comp, Comp.Image):
|
||||||
|
path = await comp.convert_to_file_path()
|
||||||
|
self.user_file_uploads.setdefault(
|
||||||
|
event.get_sender_id(), []
|
||||||
|
).append(path)
|
||||||
|
fname = os.path.basename(path)
|
||||||
|
await event.send(
|
||||||
|
event.plain_result(
|
||||||
|
f"已接收文件: {fname},继续上传或发送 /endupload 结束。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await empty_mention_waiter(event)
|
||||||
|
except TimeoutError as _:
|
||||||
|
await event.send(event.plain_result("等待上传超时,上传已取消。"))
|
||||||
|
except Exception as e:
|
||||||
|
await event.send(
|
||||||
|
event.plain_result("发生错误,请联系管理员: " + str(e))
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
event.stop_event()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("handle_empty_mention error: " + str(e))
|
||||||
46
packages/astrbot_agent/main.py
Normal file
46
packages/astrbot_agent/main.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import os
|
||||||
|
import astrbot.api.star as star
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import filter, AstrMessageEvent
|
||||||
|
from astrbot.api.provider import ProviderRequest
|
||||||
|
from astrbot.api import AstrBotConfig
|
||||||
|
from .tools.fs import CreateFileTool, ReadFileTool
|
||||||
|
from .tools.shell import ExecuteShellTool
|
||||||
|
from .tools.python import PythonTool
|
||||||
|
from .commands.file import FileCommand
|
||||||
|
|
||||||
|
|
||||||
|
class Main(star.Star):
|
||||||
|
"""AstrBot Agent"""
|
||||||
|
|
||||||
|
def __init__(self, context: star.Context, config: AstrBotConfig) -> None:
|
||||||
|
self.context = context
|
||||||
|
self.config = config
|
||||||
|
self.endpoint = config.get("endpoint", "http://localhost:8000")
|
||||||
|
self.access_token = config.get("access_token", "")
|
||||||
|
os.environ["SHIPYARD_ENDPOINT"] = self.endpoint
|
||||||
|
os.environ["SHIPYARD_ACCESS_TOKEN"] = self.access_token
|
||||||
|
|
||||||
|
context.add_llm_tool(
|
||||||
|
CreateFileTool(), ExecuteShellTool(), PythonTool(), ReadFileTool()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.file_c = FileCommand(context)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@filter.command("fileupload")
|
||||||
|
async def fileupload(self, event: AstrMessageEvent):
|
||||||
|
"""处理文件上传"""
|
||||||
|
await self.file_c.file(event)
|
||||||
|
|
||||||
|
@filter.on_llm_request()
|
||||||
|
async def on_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||||
|
"""处理 LLM 请求"""
|
||||||
|
sender_id = event.get_sender_id()
|
||||||
|
uploads = self.file_c.user_file_uploaded_files.pop(sender_id, None)
|
||||||
|
if uploads:
|
||||||
|
logger.info(f"Attaching uploaded files for user {sender_id}: {uploads}")
|
||||||
|
|
||||||
|
req.system_prompt = f"{req.system_prompt}\n\n\n# User Uploaded Files: {uploads}"
|
||||||
4
packages/astrbot_agent/metadata.yaml
Normal file
4
packages/astrbot_agent/metadata.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
name: astrbot-agent
|
||||||
|
desc: AstrBot Agent
|
||||||
|
author: Soulter
|
||||||
|
version: 0.0.1
|
||||||
40
packages/astrbot_agent/sandbox_client.py
Normal file
40
packages/astrbot_agent/sandbox_client.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from shipyard import ShipyardClient, SessionShip, Spec
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxClient:
|
||||||
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
session_ship: dict[str, SessionShip] = {}
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(SandboxClient, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not SandboxClient._initialized:
|
||||||
|
self.endpoint = os.getenv("SHIPYARD_ENDPOINT", "http://localhost:8000")
|
||||||
|
self.access_token = os.getenv("SHIPYARD_ACCESS_TOKEN", "")
|
||||||
|
self.client = ShipyardClient(
|
||||||
|
endpoint_url=self.endpoint, access_token=self.access_token
|
||||||
|
)
|
||||||
|
SandboxClient._initialized = True
|
||||||
|
|
||||||
|
async def get_ship(self, session_id: str) -> SessionShip:
|
||||||
|
if session_id not in self.session_ship:
|
||||||
|
uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex
|
||||||
|
ship = await self.client.create_ship(
|
||||||
|
ttl=3600,
|
||||||
|
spec=Spec(cpus=1.0, memory="512m"),
|
||||||
|
max_session_num=3,
|
||||||
|
session_id=uuid_str,
|
||||||
|
)
|
||||||
|
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
|
||||||
|
self.session_ship[session_id] = ship
|
||||||
|
return self.session_ship[session_id]
|
||||||
|
|
||||||
|
def get_client(self) -> ShipyardClient:
|
||||||
|
return self.client
|
||||||
60
packages/astrbot_agent/tools/fs.py
Normal file
60
packages/astrbot_agent/tools/fs.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import json
|
||||||
|
from astrbot.api import FunctionTool
|
||||||
|
from astrbot.api.event import AstrMessageEvent
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from ..sandbox_client import SandboxClient
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CreateFileTool(FunctionTool):
|
||||||
|
name: str = "astrbot_create_file"
|
||||||
|
description: str = "Create a new file in the sandbox."
|
||||||
|
parameters: dict = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"path": "string",
|
||||||
|
"description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The content to write into the file.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path", "content"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self, event: AstrMessageEvent, path: str, content: str):
|
||||||
|
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||||
|
try:
|
||||||
|
result = await sb.fs.create_file(path, content)
|
||||||
|
return json.dumps(result)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error creating file: {str(e)}"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReadFileTool(FunctionTool):
|
||||||
|
name: str = "astrbot_read_file"
|
||||||
|
description: str = "Read the content of a file in the sandbox."
|
||||||
|
parameters: dict = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self, event: AstrMessageEvent, path: str):
|
||||||
|
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||||
|
try:
|
||||||
|
result = await sb.fs.read_file(path)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error reading file: {str(e)}"
|
||||||
53
packages/astrbot_agent/tools/python.py
Normal file
53
packages/astrbot_agent/tools/python.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import mcp
|
||||||
|
from astrbot.api import FunctionTool
|
||||||
|
from astrbot.api.event import AstrMessageEvent
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from ..sandbox_client import SandboxClient
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PythonTool(FunctionTool):
|
||||||
|
name: str = "astrbot_execute_ipython"
|
||||||
|
description: str = "Execute a command in an IPython shell."
|
||||||
|
parameters: dict = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The Python code to execute.",
|
||||||
|
},
|
||||||
|
"slient": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to suppress the output of the code execution.",
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["code"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self, event: AstrMessageEvent, code: str, silent: bool = False):
|
||||||
|
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||||
|
try:
|
||||||
|
result = await sb.python.exec(code, silent=silent)
|
||||||
|
output = result.get("output", {})
|
||||||
|
images: list[dict] = output.get("images", [])
|
||||||
|
text: str = output.get("text", "")
|
||||||
|
|
||||||
|
resp = mcp.types.CallToolResult(content=[])
|
||||||
|
|
||||||
|
if images:
|
||||||
|
for img in images:
|
||||||
|
resp.content.append(
|
||||||
|
mcp.types.ImageContent(
|
||||||
|
type="image", data=img["image/png"], mimeType="image/png"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if text:
|
||||||
|
resp.content.append(mcp.types.TextContent(type="text", text=text))
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error executing code: {str(e)}"
|
||||||
48
packages/astrbot_agent/tools/shell.py
Normal file
48
packages/astrbot_agent/tools/shell.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import json
|
||||||
|
from astrbot.api import FunctionTool
|
||||||
|
from astrbot.api.event import AstrMessageEvent
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from ..sandbox_client import SandboxClient
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecuteShellTool(FunctionTool):
|
||||||
|
name: str = "astrbot_execute_shell"
|
||||||
|
description: str = "Execute a command in the shell."
|
||||||
|
parameters: dict = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The shell command to execute.",
|
||||||
|
},
|
||||||
|
"background": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to run the command in the background.",
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
"env": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Optional environment variables to set for the file creation process.",
|
||||||
|
"additionalProperties": {"type": "string"},
|
||||||
|
"default": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
command: str,
|
||||||
|
background: bool = False,
|
||||||
|
env: dict = {},
|
||||||
|
):
|
||||||
|
sb = await SandboxClient().get_ship(event.unified_msg_origin)
|
||||||
|
try:
|
||||||
|
result = await sb.shell.exec(command, background=background, env=env)
|
||||||
|
return json.dumps(result)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error executing command: {str(e)}"
|
||||||
@@ -1,519 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import shutil
|
|
||||||
import aiohttp
|
|
||||||
import uuid
|
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
import aiodocker
|
|
||||||
import time
|
|
||||||
import astrbot.api.star as star
|
|
||||||
from collections import defaultdict
|
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
|
||||||
from astrbot.api import llm_tool, logger
|
|
||||||
from astrbot.api.event import filter
|
|
||||||
from astrbot.api.provider import ProviderRequest
|
|
||||||
from astrbot.api.message_components import Image, File
|
|
||||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
||||||
|
|
||||||
PROMPT = """
|
|
||||||
## Task
|
|
||||||
You need to generate python codes to solve user's problem: {prompt}
|
|
||||||
|
|
||||||
{extra_input}
|
|
||||||
|
|
||||||
## Limit
|
|
||||||
1. Available libraries:
|
|
||||||
- standard libs
|
|
||||||
- `Pillow`
|
|
||||||
- `requests`
|
|
||||||
- `numpy`
|
|
||||||
- `matplotlib`
|
|
||||||
- `scipy`
|
|
||||||
- `scikit-learn`
|
|
||||||
- `beautifulsoup4`
|
|
||||||
- `pandas`
|
|
||||||
- `opencv-python`
|
|
||||||
- `python-docx`
|
|
||||||
- `python-pptx`
|
|
||||||
- `pymupdf` (Do not use fpdf, reportlab, etc.)
|
|
||||||
- `mplfonts`
|
|
||||||
You can only use these libraries and the libraries that they depend on.
|
|
||||||
2. Do not generate malicious code.
|
|
||||||
3. Use given `shared.api` package to output the result.
|
|
||||||
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
|
|
||||||
For Image and file, you must save it to `output` folder.
|
|
||||||
4. You must only output the code, do not output the result of the code and any other information.
|
|
||||||
5. The output language is same as user's input language.
|
|
||||||
6. Please first provide relevant knowledge about user's problem appropriately.
|
|
||||||
|
|
||||||
## Example
|
|
||||||
1. User's problem: `please solve the fabonacci sequence problem.`
|
|
||||||
Output:
|
|
||||||
```python
|
|
||||||
from shared.api import send_text, send_image, send_file
|
|
||||||
|
|
||||||
def fabonacci(n):
|
|
||||||
if n <= 1:
|
|
||||||
return n
|
|
||||||
else:
|
|
||||||
return fabonacci(n-1) + fabonacci(n-2)
|
|
||||||
|
|
||||||
result = fabonacci(10)
|
|
||||||
send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.")
|
|
||||||
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
|
|
||||||
```
|
|
||||||
|
|
||||||
2. User's problem: `please draw a sin(x) function.`
|
|
||||||
Output:
|
|
||||||
```python
|
|
||||||
from shared.api import send_text, send_image, send_file
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
x = np.linspace(0, 2*np.pi, 100)
|
|
||||||
y = np.sin(x)
|
|
||||||
plt.plot(x, y)
|
|
||||||
plt.savefig("output/sin_x.png")
|
|
||||||
send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).")
|
|
||||||
send_image("output/sin_x.png") # send_image is a function to send image to user
|
|
||||||
send_text("If you need more information, please let me know :)")
|
|
||||||
```
|
|
||||||
|
|
||||||
{extra_prompt}
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_CONFIG = {
|
|
||||||
"sandbox": {
|
|
||||||
"image": "soulter/astrbot-code-interpreter-sandbox",
|
|
||||||
"docker_mirror": "", # cjie.eu.org
|
|
||||||
},
|
|
||||||
"docker_host_astrbot_abs_path": "",
|
|
||||||
}
|
|
||||||
PATH = os.path.join(get_astrbot_data_path(), "config", "python_interpreter.json")
|
|
||||||
|
|
||||||
|
|
||||||
class Main(star.Star):
|
|
||||||
"""基于 Docker 沙箱的 Python 代码执行器"""
|
|
||||||
|
|
||||||
def __init__(self, context: star.Context) -> None:
|
|
||||||
self.context = context
|
|
||||||
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
self.shared_path = os.path.join("data", "py_interpreter_shared")
|
|
||||||
if not os.path.exists(self.shared_path):
|
|
||||||
# 复制 api.py 到 shared 目录
|
|
||||||
os.makedirs(self.shared_path, exist_ok=True)
|
|
||||||
shared_api_file = os.path.join(self.curr_dir, "shared", "api.py")
|
|
||||||
shutil.copy(shared_api_file, self.shared_path)
|
|
||||||
self.workplace_path = os.path.join("data", "py_interpreter_workplace")
|
|
||||||
os.makedirs(self.workplace_path, exist_ok=True)
|
|
||||||
|
|
||||||
self.user_file_msg_buffer = defaultdict(list)
|
|
||||||
"""存放用户上传的文件和图片"""
|
|
||||||
self.user_waiting = {}
|
|
||||||
"""正在等待用户的文件或图片"""
|
|
||||||
|
|
||||||
# 加载配置
|
|
||||||
if not os.path.exists(PATH):
|
|
||||||
self.config = DEFAULT_CONFIG
|
|
||||||
self._save_config()
|
|
||||||
else:
|
|
||||||
with open(PATH, "r") as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
ok = await self.is_docker_available()
|
|
||||||
if not ok:
|
|
||||||
logger.info(
|
|
||||||
"Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。"
|
|
||||||
)
|
|
||||||
# await self.context._star_manager.turn_off_plugin(
|
|
||||||
# "astrbot-python-interpreter"
|
|
||||||
# )
|
|
||||||
|
|
||||||
async def file_upload(self, file_path: str):
|
|
||||||
"""
|
|
||||||
上传图像文件到 S3
|
|
||||||
"""
|
|
||||||
ext = os.path.splitext(file_path)[1]
|
|
||||||
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
file = f.read()
|
|
||||||
|
|
||||||
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
|
||||||
headers={"Accept": "application/json"}, trust_env=True
|
|
||||||
) as session:
|
|
||||||
async with session.put(s3_file_url, data=file) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
raise Exception(f"Failed to upload image: {resp.status}")
|
|
||||||
return s3_file_url
|
|
||||||
|
|
||||||
async def is_docker_available(self) -> bool:
|
|
||||||
"""Check if docker is available"""
|
|
||||||
try:
|
|
||||||
docker = aiodocker.Docker()
|
|
||||||
await docker.version()
|
|
||||||
await docker.close()
|
|
||||||
return True
|
|
||||||
except BaseException as e:
|
|
||||||
logger.info(f"检查 Docker 可用性: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_image_name(self) -> str:
|
|
||||||
"""Get the image name"""
|
|
||||||
if self.config["sandbox"]["docker_mirror"]:
|
|
||||||
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
|
|
||||||
return self.config["sandbox"]["image"]
|
|
||||||
|
|
||||||
def _save_config(self):
|
|
||||||
with open(PATH, "w") as f:
|
|
||||||
json.dump(self.config, f)
|
|
||||||
|
|
||||||
async def gen_magic_code(self) -> str:
|
|
||||||
return uuid.uuid4().hex[:8]
|
|
||||||
|
|
||||||
async def download_image(
|
|
||||||
self, image_url: str, workplace_path: str, filename: str
|
|
||||||
) -> str:
|
|
||||||
"""Download image from url to workplace_path"""
|
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
||||||
async with session.get(image_url) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
return ""
|
|
||||||
image_path = os.path.join(workplace_path, f"{filename}.jpg")
|
|
||||||
with open(image_path, "wb") as f:
|
|
||||||
f.write(await resp.read())
|
|
||||||
return f"{filename}.jpg"
|
|
||||||
|
|
||||||
async def tidy_code(self, code: str) -> str:
|
|
||||||
"""Tidy the code"""
|
|
||||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
|
||||||
match = re.search(pattern, code, re.DOTALL)
|
|
||||||
if match is None:
|
|
||||||
raise ValueError("The code is not in the code block.")
|
|
||||||
return match.group(1)
|
|
||||||
|
|
||||||
@filter.event_message_type(filter.EventMessageType.ALL)
|
|
||||||
async def on_message(self, event: AstrMessageEvent):
|
|
||||||
"""处理消息"""
|
|
||||||
uid = event.get_sender_id()
|
|
||||||
if uid not in self.user_waiting:
|
|
||||||
return
|
|
||||||
for comp in event.message_obj.message:
|
|
||||||
if isinstance(comp, File):
|
|
||||||
file_path = await comp.get_file()
|
|
||||||
if file_path.startswith("http"):
|
|
||||||
name = comp.name if comp.name else uuid.uuid4().hex[:8]
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
path = os.path.join(temp_dir, name)
|
|
||||||
await download_file(file_path, path)
|
|
||||||
else:
|
|
||||||
path = file_path
|
|
||||||
self.user_file_msg_buffer[event.get_session_id()].append(path)
|
|
||||||
logger.debug(f"User {uid} uploaded file: {path}")
|
|
||||||
yield event.plain_result(f"代码执行器: 文件已经上传: {path}")
|
|
||||||
if uid in self.user_waiting:
|
|
||||||
del self.user_waiting[uid]
|
|
||||||
elif isinstance(comp, Image):
|
|
||||||
image_url = comp.url if comp.url else comp.file
|
|
||||||
if image_url.startswith("http"):
|
|
||||||
image_path = await download_image_by_url(image_url)
|
|
||||||
elif image_url.startswith("file:///"):
|
|
||||||
image_path = image_url.replace("file:///", "")
|
|
||||||
else:
|
|
||||||
image_path = image_url
|
|
||||||
self.user_file_msg_buffer[event.get_session_id()].append(image_path)
|
|
||||||
logger.debug(f"User {uid} uploaded image: {image_path}")
|
|
||||||
yield event.plain_result(f"代码执行器: 图片已经上传: {image_path}")
|
|
||||||
if uid in self.user_waiting:
|
|
||||||
del self.user_waiting[uid]
|
|
||||||
|
|
||||||
@filter.on_llm_request()
|
|
||||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
|
||||||
if event.get_session_id() in self.user_file_msg_buffer:
|
|
||||||
files = self.user_file_msg_buffer[event.get_session_id()]
|
|
||||||
request.prompt += f"\nUser provided files: {files}"
|
|
||||||
|
|
||||||
@filter.command_group("pi")
|
|
||||||
def pi(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@pi.command("absdir")
|
|
||||||
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
|
|
||||||
"""设置 Docker 宿主机绝对路径"""
|
|
||||||
if not path:
|
|
||||||
yield event.plain_result(
|
|
||||||
f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.config["docker_host_astrbot_abs_path"] = path
|
|
||||||
self._save_config()
|
|
||||||
yield event.plain_result(f"设置 Docker 宿主机绝对路径成功: {path}")
|
|
||||||
|
|
||||||
@pi.command("mirror")
|
|
||||||
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
|
|
||||||
"""Docker 镜像地址"""
|
|
||||||
if not url:
|
|
||||||
yield event.plain_result(f"""当前 Docker 镜像地址: {self.config["sandbox"]["docker_mirror"]}。
|
|
||||||
使用 `pi mirror <url>` 来设置 Docker 镜像地址。
|
|
||||||
您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。
|
|
||||||
""")
|
|
||||||
else:
|
|
||||||
self.config["sandbox"]["docker_mirror"] = url
|
|
||||||
self._save_config()
|
|
||||||
yield event.plain_result("设置 Docker 镜像地址成功。")
|
|
||||||
|
|
||||||
@pi.command("repull")
|
|
||||||
async def pi_repull(self, event: AstrMessageEvent):
|
|
||||||
"""重新拉取沙箱镜像"""
|
|
||||||
docker = aiodocker.Docker()
|
|
||||||
image_name = await self.get_image_name()
|
|
||||||
try:
|
|
||||||
await docker.images.get(image_name)
|
|
||||||
await docker.images.delete(image_name, force=True)
|
|
||||||
except aiodocker.exceptions.DockerError:
|
|
||||||
pass
|
|
||||||
await docker.images.pull(image_name)
|
|
||||||
yield event.plain_result("重新拉取沙箱镜像成功。")
|
|
||||||
|
|
||||||
@pi.command("file")
|
|
||||||
async def pi_file(self, event: AstrMessageEvent):
|
|
||||||
"""在规定秒数(60s)内上传一个文件"""
|
|
||||||
uid = event.get_sender_id()
|
|
||||||
self.user_waiting[uid] = time.time()
|
|
||||||
tip = "文件"
|
|
||||||
yield event.plain_result(f"代码执行器: 请在 60s 内上传一个{tip}。")
|
|
||||||
await asyncio.sleep(60)
|
|
||||||
if uid in self.user_waiting:
|
|
||||||
yield event.plain_result(
|
|
||||||
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。"
|
|
||||||
)
|
|
||||||
self.user_waiting.pop(uid)
|
|
||||||
|
|
||||||
@pi.command("clear", alias=["clean"])
|
|
||||||
async def pi_file_clean(self, event: AstrMessageEvent):
|
|
||||||
"""清理用户上传的文件"""
|
|
||||||
uid = event.get_sender_id()
|
|
||||||
if uid in self.user_waiting:
|
|
||||||
self.user_waiting.pop(uid)
|
|
||||||
yield event.plain_result(
|
|
||||||
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 已清理。"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield event.plain_result(
|
|
||||||
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有等待上传文件。"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pi.command("list")
|
|
||||||
async def pi_file_list(self, event: AstrMessageEvent):
|
|
||||||
"""列出用户上传的文件"""
|
|
||||||
uid = event.get_sender_id()
|
|
||||||
if uid in self.user_file_msg_buffer:
|
|
||||||
files = self.user_file_msg_buffer[uid]
|
|
||||||
yield event.plain_result(
|
|
||||||
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 上传的文件: {files}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield event.plain_result(
|
|
||||||
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有上传文件。"
|
|
||||||
)
|
|
||||||
|
|
||||||
@llm_tool("python_interpreter")
|
|
||||||
async def python_interpreter(self, event: AstrMessageEvent):
|
|
||||||
"""Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
|
|
||||||
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
|
|
||||||
"""
|
|
||||||
if not await self.is_docker_available():
|
|
||||||
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
|
|
||||||
|
|
||||||
plain_text = event.message_str
|
|
||||||
|
|
||||||
# 创建必要的工作目录和幻术码
|
|
||||||
magic_code = await self.gen_magic_code()
|
|
||||||
workplace_path = os.path.join(self.workplace_path, magic_code)
|
|
||||||
output_path = os.path.join(workplace_path, "output")
|
|
||||||
os.makedirs(workplace_path, exist_ok=True)
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
|
|
||||||
files = []
|
|
||||||
# 文件
|
|
||||||
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
|
|
||||||
if not file_path:
|
|
||||||
continue
|
|
||||||
elif not os.path.exists(file_path):
|
|
||||||
logger.warning(f"文件 {file_path} 不存在,已忽略。")
|
|
||||||
continue
|
|
||||||
# cp
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
shutil.copy(file_path, os.path.join(workplace_path, file_name))
|
|
||||||
files.append(file_name)
|
|
||||||
|
|
||||||
logger.debug(f"user query: {plain_text}, files: {files}")
|
|
||||||
|
|
||||||
# 整理额外输入
|
|
||||||
extra_inputs = ""
|
|
||||||
if files:
|
|
||||||
extra_inputs += f"User provided files: {files}\n"
|
|
||||||
|
|
||||||
obs = ""
|
|
||||||
n = 5
|
|
||||||
|
|
||||||
for i in range(n):
|
|
||||||
if i > 0:
|
|
||||||
logger.info(f"Try {i + 1}/{n}")
|
|
||||||
|
|
||||||
PROMPT_ = PROMPT.format(
|
|
||||||
prompt=plain_text,
|
|
||||||
extra_input=extra_inputs,
|
|
||||||
extra_prompt=obs,
|
|
||||||
)
|
|
||||||
provider = self.context.get_using_provider()
|
|
||||||
llm_response = await provider.text_chat(
|
|
||||||
prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"code interpreter llm gened code:" + llm_response.completion_text
|
|
||||||
)
|
|
||||||
|
|
||||||
# 整理代码并保存
|
|
||||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
|
||||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
|
||||||
f.write(code_clean)
|
|
||||||
|
|
||||||
# 启动容器
|
|
||||||
docker = aiodocker.Docker()
|
|
||||||
|
|
||||||
# 检查有没有image
|
|
||||||
image_name = await self.get_image_name()
|
|
||||||
try:
|
|
||||||
await docker.images.get(image_name)
|
|
||||||
except aiodocker.exceptions.DockerError:
|
|
||||||
# 拉取镜像
|
|
||||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
|
||||||
await docker.images.pull(image_name)
|
|
||||||
|
|
||||||
yield event.plain_result(
|
|
||||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.docker_host_astrbot_abs_path = self.config.get(
|
|
||||||
"docker_host_astrbot_abs_path", ""
|
|
||||||
)
|
|
||||||
if self.docker_host_astrbot_abs_path:
|
|
||||||
host_shared = os.path.join(
|
|
||||||
self.docker_host_astrbot_abs_path, self.shared_path
|
|
||||||
)
|
|
||||||
host_output = os.path.join(
|
|
||||||
self.docker_host_astrbot_abs_path, output_path
|
|
||||||
)
|
|
||||||
host_workplace = os.path.join(
|
|
||||||
self.docker_host_astrbot_abs_path, workplace_path
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
host_shared = os.path.abspath(self.shared_path)
|
|
||||||
host_output = os.path.abspath(output_path)
|
|
||||||
host_workplace = os.path.abspath(workplace_path)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}"
|
|
||||||
)
|
|
||||||
|
|
||||||
container = await docker.containers.run(
|
|
||||||
{
|
|
||||||
"Image": image_name,
|
|
||||||
"Cmd": ["python", "exec.py"],
|
|
||||||
"Memory": 512 * 1024 * 1024,
|
|
||||||
"NanoCPUs": 1000000000,
|
|
||||||
"HostConfig": {
|
|
||||||
"Binds": [
|
|
||||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
|
||||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
|
||||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
|
||||||
"AutoRemove": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Container {container.id} created.")
|
|
||||||
logs = await self.run_container(container)
|
|
||||||
|
|
||||||
logger.debug(f"Container {container.id} finished.")
|
|
||||||
logger.debug(f"Container {container.id} logs: {logs}")
|
|
||||||
|
|
||||||
# 发送结果
|
|
||||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
|
||||||
ok = False
|
|
||||||
traceback = ""
|
|
||||||
for idx, log in enumerate(logs):
|
|
||||||
match = re.match(pattern, log)
|
|
||||||
if match:
|
|
||||||
ok = True
|
|
||||||
if match.group(1) == "TEXT":
|
|
||||||
yield event.plain_result(match.group(2))
|
|
||||||
elif match.group(1) == "IMAGE":
|
|
||||||
image_path = os.path.join(workplace_path, match.group(2))
|
|
||||||
logger.debug(f"Sending image: {image_path}")
|
|
||||||
yield event.image_result(image_path)
|
|
||||||
elif match.group(1) == "FILE":
|
|
||||||
file_path = os.path.join(workplace_path, match.group(2))
|
|
||||||
# logger.debug(f"Sending file: {file_path}")
|
|
||||||
# file_s3_url = await self.file_upload(file_path)
|
|
||||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
chain = [File(name=file_name, file=file_path)]
|
|
||||||
yield event.set_result(MessageEventResult(chain=chain))
|
|
||||||
|
|
||||||
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
|
||||||
traceback = "\n".join(logs[idx:])
|
|
||||||
|
|
||||||
if not ok:
|
|
||||||
if traceback:
|
|
||||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# 成功了
|
|
||||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
|
||||||
return
|
|
||||||
|
|
||||||
yield event.plain_result(
|
|
||||||
"经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pi.command("cleanfile")
|
|
||||||
async def pi_cleanfile(self, event: AstrMessageEvent):
|
|
||||||
"""清理用户上传的文件"""
|
|
||||||
for file in self.user_file_msg_buffer[event.get_session_id()]:
|
|
||||||
try:
|
|
||||||
os.remove(file)
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(f"删除文件 {file} 失败: {e}")
|
|
||||||
|
|
||||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
|
||||||
yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。")
|
|
||||||
|
|
||||||
async def run_container(
|
|
||||||
self, container: aiodocker.docker.DockerContainer, timeout: int = 20
|
|
||||||
) -> list[str]:
|
|
||||||
"""Run the container and get the output"""
|
|
||||||
try:
|
|
||||||
await container.wait(timeout=timeout)
|
|
||||||
logs = await container.log(stdout=True, stderr=True)
|
|
||||||
return logs
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(f"Container {container.id} timeout.")
|
|
||||||
await container.kill()
|
|
||||||
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
|
|
||||||
finally:
|
|
||||||
await container.delete()
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
name: astrbot-python-interpreter
|
|
||||||
desc: Python 代码执行器
|
|
||||||
author: Soulter
|
|
||||||
version: 0.0.1
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
aiodocker
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def _get_magic_code():
|
|
||||||
"""防止注入攻击"""
|
|
||||||
return os.getenv("MAGIC_CODE")
|
|
||||||
|
|
||||||
|
|
||||||
def send_text(text: str):
|
|
||||||
print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}")
|
|
||||||
|
|
||||||
|
|
||||||
def send_image(image_path: str):
|
|
||||||
if not os.path.exists(image_path):
|
|
||||||
raise Exception(f"Image file not found: {image_path}")
|
|
||||||
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def send_file(file_path: str):
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise Exception(f"File not found: {file_path}")
|
|
||||||
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")
|
|
||||||
@@ -50,6 +50,7 @@ dependencies = [
|
|||||||
"wechatpy>=1.8.18",
|
"wechatpy>=1.8.18",
|
||||||
"audioop-lts ; python_full_version >= '3.13'",
|
"audioop-lts ; python_full_version >= '3.13'",
|
||||||
"click>=8.2.1",
|
"click>=8.2.1",
|
||||||
|
"shipyard-python-sdk>=0.2.3",
|
||||||
"pypdf>=6.1.1",
|
"pypdf>=6.1.1",
|
||||||
"aiofiles>=25.1.0",
|
"aiofiles>=25.1.0",
|
||||||
"rank-bm25>=0.2.2",
|
"rank-bm25>=0.2.2",
|
||||||
|
|||||||
Reference in New Issue
Block a user