Compare commits

...

5 Commits

18 changed files with 3984 additions and 604 deletions

View File

@@ -151,11 +151,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
for tool_call_name, tool_call_id in zip(
llm_resp.tools_call_name, llm_resp.tools_call_ids
):
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 用工具: {tool_call_name}")
chain=MessageChain().message(f"🔨 正在使用工具: {tool_call_name} ({tool_call_id})")
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
@@ -255,63 +257,48 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async for resp in executor: # type: ignore
if isinstance(resp, CallToolResult):
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
content = res.content
aggr_text_content = ""
for cont in content:
if isinstance(cont, TextContent):
aggr_text_content += cont.text
yield MessageChain().message(cont.text)
elif isinstance(cont, ImageContent):
aggr_text_content += "\n返回了图片(已直接发送给用户)\n"
yield MessageChain(
type="tool_direct_result"
).base64_image(resource.blob)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
).base64_image(cont.data)
elif isinstance(cont, EmbeddedResource):
resource = cont.resource
if isinstance(resource, TextResourceContents):
aggr_text_content += resource.text
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
aggr_text_content += (
"\n返回了图片(已直接发送给用户)\n"
)
yield MessageChain(
type="tool_direct_result"
).base64_image(resource.blob)
else:
aggr_text_content += "\n返回的数据类型不受支持。\n"
yield MessageChain().message(
"返回的数据类型不受支持。"
)
)
yield MessageChain().message("返回的数据类型不受支持。")
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=aggr_text_content,
)
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。

View File

@@ -813,7 +813,8 @@ class File(BaseMessageComponent):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
fname = self.name if self.name else uuid.uuid4().hex
file_path = os.path.join(download_dir, fname)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)

View File

@@ -232,7 +232,9 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file")
file_name = m["data"].get(
"file_name", m["data"].get("file", "file")
)
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:

View File

@@ -11,7 +11,7 @@ from astrbot.core.provider.provider import (
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
@@ -258,6 +258,11 @@ class Context:
return True
return False
def add_llm_tool(self, *tools: FunctionTool) -> None:
"""添加一个 LLM 工具。"""
for tool in tools:
self.provider_manager.llm_tools.func_list.append(tool)
"""
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""

View File

@@ -0,0 +1,10 @@
{
"endpoint": {
"description": "The endpoint URL of the Shipyard server.",
"type": "string"
},
"access_token": {
"description": "The access token for authenticating with the Shipyard server.",
"type": "string"
}
}

View File

@@ -0,0 +1,129 @@
import os
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api import logger
from astrbot.api.all import Context
import astrbot.api.message_components as Comp
from astrbot.core.utils.session_waiter import (
session_waiter,
SessionController,
)
from ..sandbox_client import SandboxClient
class FileCommand:
def __init__(self, context: Context) -> None:
self.context = context
self.user_file_uploads: dict[str, list[str]] = {} # umo -> file_path
self.user_file_uploaded_files: dict[str, list[str]] = {} # umo -> file_path
"""记录用户上传过的文件,保存了文件在沙箱中的路径。这个在用户下一次 LLM 请求时会被用到,然后清空。"""
async def _upload_file_to_sandbox(self, event: AstrMessageEvent) -> list[str]:
"""将用户上传的文件上传到沙箱"""
sender_id = event.get_sender_id()
sb = await SandboxClient().get_ship(event.unified_msg_origin)
fpath_ls = self.user_file_uploads[sender_id]
errors = []
for path in fpath_ls:
try:
fname = os.path.basename(path)
data = await sb.upload_file(path, fname)
success = data.get("success", False)
if not success:
raise Exception(f"Upload failed: {data}")
file_path = data.get("file_path", "")
logger.info(f"File {fname} uploaded to sandbox at {file_path}")
self.user_file_uploaded_files.setdefault(sender_id, []).append(
file_path
)
except Exception as e:
errors.append((path, str(e)))
logger.error(f"Error uploading file {path}: {e}")
# clean up files
for path in fpath_ls:
try:
os.remove(path)
except Exception as e:
logger.error(f"Error removing temp file {path}: {e}")
return errors
async def file(self, event: AstrMessageEvent):
"""等待用户上传文件或图片"""
await event.send(
MessageChain().message(
f"请上传一个或多个文件(或图片),使用 /endupload 结束上传。(请求者 ID: {event.get_sender_id()})"
)
)
try:
@session_waiter(timeout=600, record_history_chains=False) # type: ignore
async def empty_mention_waiter(
controller: SessionController, event: AstrMessageEvent
):
idiom = event.message_str
sender_id = event.get_sender_id()
if idiom == "endupload":
files = self.user_file_uploads.get(sender_id, [])
if not files:
await event.send(
event.plain_result("你没有上传任何文件,上传已取消。")
)
controller.stop()
return
await event.send(
event.plain_result(f"开始上传 {len(files)} 个文件到沙箱...")
)
errors = await self._upload_file_to_sandbox(event)
if errors:
error_msgs = "\n".join(
[f"{path}: {err}" for path, err in errors]
)
await event.send(
event.plain_result(
f"上传中出现错误:\n{error_msgs}\n其他文件已成功上传。"
)
)
else:
await event.send(
event.plain_result(
f"上传完毕,共上传 {len(files)} 个文件。文件信息已被保存,下一次 LLM 请求时会自动将信息附上。"
)
)
self.user_file_uploads.pop(sender_id, None)
controller.stop()
return
# 解析文件或图片消息
for comp in event.message_obj.message:
if isinstance(comp, (Comp.File, Comp.Image)):
if isinstance(comp, Comp.File):
path = await comp.get_file()
self.user_file_uploads.setdefault(
event.get_sender_id(), []
).append(path)
elif isinstance(comp, Comp.Image):
path = await comp.convert_to_file_path()
self.user_file_uploads.setdefault(
event.get_sender_id(), []
).append(path)
fname = os.path.basename(path)
await event.send(
event.plain_result(
f"已接收文件: {fname},继续上传或发送 /endupload 结束。"
)
)
try:
await empty_mention_waiter(event)
except TimeoutError as _:
await event.send(event.plain_result("等待上传超时,上传已取消。"))
except Exception as e:
await event.send(
event.plain_result("发生错误,请联系管理员: " + str(e))
)
finally:
event.stop_event()
except Exception as e:
logger.error("handle_empty_mention error: " + str(e))

View File

@@ -0,0 +1,46 @@
import os
import astrbot.api.star as star
from astrbot.api import logger
from astrbot.api.event import filter, AstrMessageEvent
from astrbot.api.provider import ProviderRequest
from astrbot.api import AstrBotConfig
from .tools.fs import CreateFileTool, ReadFileTool
from .tools.shell import ExecuteShellTool
from .tools.python import PythonTool
from .commands.file import FileCommand
class Main(star.Star):
"""AstrBot Agent"""
def __init__(self, context: star.Context, config: AstrBotConfig) -> None:
self.context = context
self.config = config
self.endpoint = config.get("endpoint", "http://localhost:8000")
self.access_token = config.get("access_token", "")
os.environ["SHIPYARD_ENDPOINT"] = self.endpoint
os.environ["SHIPYARD_ACCESS_TOKEN"] = self.access_token
context.add_llm_tool(
CreateFileTool(), ExecuteShellTool(), PythonTool(), ReadFileTool()
)
self.file_c = FileCommand(context)
async def initialize(self):
pass
@filter.command("fileupload")
async def fileupload(self, event: AstrMessageEvent):
"""处理文件上传"""
await self.file_c.file(event)
@filter.on_llm_request()
async def on_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
"""处理 LLM 请求"""
sender_id = event.get_sender_id()
uploads = self.file_c.user_file_uploaded_files.pop(sender_id, None)
if uploads:
logger.info(f"Attaching uploaded files for user {sender_id}: {uploads}")
req.system_prompt = f"{req.system_prompt}\n\n\n# User Uploaded Files: {uploads}"

View File

@@ -0,0 +1,4 @@
name: astrbot-agent
desc: AstrBot Agent
author: Soulter
version: 0.0.1

View File

@@ -0,0 +1,40 @@
import os
import uuid
from shipyard import ShipyardClient, SessionShip, Spec
from astrbot.api import logger
class SandboxClient:
_instance = None
_initialized = False
session_ship: dict[str, SessionShip] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super(SandboxClient, cls).__new__(cls)
return cls._instance
def __init__(self):
if not SandboxClient._initialized:
self.endpoint = os.getenv("SHIPYARD_ENDPOINT", "http://localhost:8000")
self.access_token = os.getenv("SHIPYARD_ACCESS_TOKEN", "")
self.client = ShipyardClient(
endpoint_url=self.endpoint, access_token=self.access_token
)
SandboxClient._initialized = True
async def get_ship(self, session_id: str) -> SessionShip:
if session_id not in self.session_ship:
uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex
ship = await self.client.create_ship(
ttl=3600,
spec=Spec(cpus=1.0, memory="512m"),
max_session_num=3,
session_id=uuid_str,
)
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
self.session_ship[session_id] = ship
return self.session_ship[session_id]
def get_client(self) -> ShipyardClient:
return self.client

View File

@@ -0,0 +1,60 @@
import json
from astrbot.api import FunctionTool
from astrbot.api.event import AstrMessageEvent
from dataclasses import dataclass, field
from ..sandbox_client import SandboxClient
@dataclass
class CreateFileTool(FunctionTool):
name: str = "astrbot_create_file"
description: str = "Create a new file in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"path": {
"path": "string",
"description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
},
"content": {
"type": "string",
"description": "The content to write into the file.",
},
},
"required": ["path", "content"],
}
)
async def run(self, event: AstrMessageEvent, path: str, content: str):
sb = await SandboxClient().get_ship(event.unified_msg_origin)
try:
result = await sb.fs.create_file(path, content)
return json.dumps(result)
except Exception as e:
return f"Error creating file: {str(e)}"
@dataclass
class ReadFileTool(FunctionTool):
name: str = "astrbot_read_file"
description: str = "Read the content of a file in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
},
},
"required": ["path"],
}
)
async def run(self, event: AstrMessageEvent, path: str):
sb = await SandboxClient().get_ship(event.unified_msg_origin)
try:
result = await sb.fs.read_file(path)
return result
except Exception as e:
return f"Error reading file: {str(e)}"

View File

@@ -0,0 +1,53 @@
import mcp
from astrbot.api import FunctionTool
from astrbot.api.event import AstrMessageEvent
from dataclasses import dataclass, field
from ..sandbox_client import SandboxClient
@dataclass
class PythonTool(FunctionTool):
name: str = "astrbot_execute_ipython"
description: str = "Execute a command in an IPython shell."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The Python code to execute.",
},
"slient": {
"type": "boolean",
"description": "Whether to suppress the output of the code execution.",
"default": False,
},
},
"required": ["code"],
}
)
async def run(self, event: AstrMessageEvent, code: str, silent: bool = False):
sb = await SandboxClient().get_ship(event.unified_msg_origin)
try:
result = await sb.python.exec(code, silent=silent)
output = result.get("output", {})
images: list[dict] = output.get("images", [])
text: str = output.get("text", "")
resp = mcp.types.CallToolResult(content=[])
if images:
for img in images:
resp.content.append(
mcp.types.ImageContent(
type="image", data=img["image/png"], mimeType="image/png"
)
)
if text:
resp.content.append(mcp.types.TextContent(type="text", text=text))
return resp
except Exception as e:
return f"Error executing code: {str(e)}"

View File

@@ -0,0 +1,48 @@
import json
from astrbot.api import FunctionTool
from astrbot.api.event import AstrMessageEvent
from dataclasses import dataclass, field
from ..sandbox_client import SandboxClient
@dataclass
class ExecuteShellTool(FunctionTool):
name: str = "astrbot_execute_shell"
description: str = "Execute a command in the shell."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute.",
},
"background": {
"type": "boolean",
"description": "Whether to run the command in the background.",
"default": False,
},
"env": {
"type": "object",
"description": "Optional environment variables to set for the file creation process.",
"additionalProperties": {"type": "string"},
"default": {},
},
},
"required": ["command"],
}
)
async def run(
self,
event: AstrMessageEvent,
command: str,
background: bool = False,
env: dict = {},
):
sb = await SandboxClient().get_ship(event.unified_msg_origin)
try:
result = await sb.shell.exec(command, background=background, env=env)
return json.dumps(result)
except Exception as e:
return f"Error executing command: {str(e)}"

View File

@@ -1,519 +0,0 @@
import os
import json
import shutil
import aiohttp
import uuid
import asyncio
import re
import aiodocker
import time
import astrbot.api.star as star
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from astrbot.api.event import filter
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Image, File
from astrbot.core.utils.io import download_image_by_url, download_file
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
PROMPT = """
## Task
You need to generate python codes to solve user's problem: {prompt}
{extra_input}
## Limit
1. Available libraries:
- standard libs
- `Pillow`
- `requests`
- `numpy`
- `matplotlib`
- `scipy`
- `scikit-learn`
- `beautifulsoup4`
- `pandas`
- `opencv-python`
- `python-docx`
- `python-pptx`
- `pymupdf` (Do not use fpdf, reportlab, etc.)
- `mplfonts`
You can only use these libraries and the libraries that they depend on.
2. Do not generate malicious code.
3. Use given `shared.api` package to output the result.
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
For Image and file, you must save it to `output` folder.
4. You must only output the code, do not output the result of the code and any other information.
5. The output language is same as user's input language.
6. Please first provide relevant knowledge about user's problem appropriately.
## Example
1. User's problem: `please solve the fabonacci sequence problem.`
Output:
```python
from shared.api import send_text, send_image, send_file
def fabonacci(n):
if n <= 1:
return n
else:
return fabonacci(n-1) + fabonacci(n-2)
result = fabonacci(10)
send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.")
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
```
2. User's problem: `please draw a sin(x) function.`
Output:
```python
from shared.api import send_text, send_image, send_file
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
plt.plot(x, y)
plt.savefig("output/sin_x.png")
send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).")
send_image("output/sin_x.png") # send_image is a function to send image to user
send_text("If you need more information, please let me know :)")
```
{extra_prompt}
"""
DEFAULT_CONFIG = {
"sandbox": {
"image": "soulter/astrbot-code-interpreter-sandbox",
"docker_mirror": "", # cjie.eu.org
},
"docker_host_astrbot_abs_path": "",
}
PATH = os.path.join(get_astrbot_data_path(), "config", "python_interpreter.json")
class Main(star.Star):
"""基于 Docker 沙箱的 Python 代码执行器"""
def __init__(self, context: star.Context) -> None:
self.context = context
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
self.shared_path = os.path.join("data", "py_interpreter_shared")
if not os.path.exists(self.shared_path):
# 复制 api.py 到 shared 目录
os.makedirs(self.shared_path, exist_ok=True)
shared_api_file = os.path.join(self.curr_dir, "shared", "api.py")
shutil.copy(shared_api_file, self.shared_path)
self.workplace_path = os.path.join("data", "py_interpreter_workplace")
os.makedirs(self.workplace_path, exist_ok=True)
self.user_file_msg_buffer = defaultdict(list)
"""存放用户上传的文件和图片"""
self.user_waiting = {}
"""正在等待用户的文件或图片"""
# 加载配置
if not os.path.exists(PATH):
self.config = DEFAULT_CONFIG
self._save_config()
else:
with open(PATH, "r") as f:
self.config = json.load(f)
async def initialize(self):
ok = await self.is_docker_available()
if not ok:
logger.info(
"Docker 不可用代码解释器将无法使用astrbot-python-interpreter 将自动禁用。"
)
# await self.context._star_manager.turn_off_plugin(
# "astrbot-python-interpreter"
# )
async def file_upload(self, file_path: str):
"""
上传图像文件到 S3
"""
ext = os.path.splitext(file_path)[1]
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
with open(file_path, "rb") as f:
file = f.read()
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
async with aiohttp.ClientSession(
headers={"Accept": "application/json"}, trust_env=True
) as session:
async with session.put(s3_file_url, data=file) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
return s3_file_url
async def is_docker_available(self) -> bool:
"""Check if docker is available"""
try:
docker = aiodocker.Docker()
await docker.version()
await docker.close()
return True
except BaseException as e:
logger.info(f"检查 Docker 可用性: {e}")
return False
async def get_image_name(self) -> str:
"""Get the image name"""
if self.config["sandbox"]["docker_mirror"]:
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
return self.config["sandbox"]["image"]
def _save_config(self):
with open(PATH, "w") as f:
json.dump(self.config, f)
async def gen_magic_code(self) -> str:
return uuid.uuid4().hex[:8]
async def download_image(
self, image_url: str, workplace_path: str, filename: str
) -> str:
"""Download image from url to workplace_path"""
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(image_url) as resp:
if resp.status != 200:
return ""
image_path = os.path.join(workplace_path, f"{filename}.jpg")
with open(image_path, "wb") as f:
f.write(await resp.read())
return f"{filename}.jpg"
async def tidy_code(self, code: str) -> str:
"""Tidy the code"""
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code, re.DOTALL)
if match is None:
raise ValueError("The code is not in the code block.")
return match.group(1)
@filter.event_message_type(filter.EventMessageType.ALL)
async def on_message(self, event: AstrMessageEvent):
"""处理消息"""
uid = event.get_sender_id()
if uid not in self.user_waiting:
return
for comp in event.message_obj.message:
if isinstance(comp, File):
file_path = await comp.get_file()
if file_path.startswith("http"):
name = comp.name if comp.name else uuid.uuid4().hex[:8]
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, name)
await download_file(file_path, path)
else:
path = file_path
self.user_file_msg_buffer[event.get_session_id()].append(path)
logger.debug(f"User {uid} uploaded file: {path}")
yield event.plain_result(f"代码执行器: 文件已经上传: {path}")
if uid in self.user_waiting:
del self.user_waiting[uid]
elif isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
else:
image_path = image_url
self.user_file_msg_buffer[event.get_session_id()].append(image_path)
logger.debug(f"User {uid} uploaded image: {image_path}")
yield event.plain_result(f"代码执行器: 图片已经上传: {image_path}")
if uid in self.user_waiting:
del self.user_waiting[uid]
@filter.on_llm_request()
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
def pi(self):
pass
@pi.command("absdir")
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
"""设置 Docker 宿主机绝对路径"""
if not path:
yield event.plain_result(
f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}"
)
else:
self.config["docker_host_astrbot_abs_path"] = path
self._save_config()
yield event.plain_result(f"设置 Docker 宿主机绝对路径成功: {path}")
@pi.command("mirror")
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
"""Docker 镜像地址"""
if not url:
yield event.plain_result(f"""当前 Docker 镜像地址: {self.config["sandbox"]["docker_mirror"]}
使用 `pi mirror <url>` 来设置 Docker 镜像地址。
您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。
""")
else:
self.config["sandbox"]["docker_mirror"] = url
self._save_config()
yield event.plain_result("设置 Docker 镜像地址成功。")
@pi.command("repull")
async def pi_repull(self, event: AstrMessageEvent):
"""重新拉取沙箱镜像"""
docker = aiodocker.Docker()
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
await docker.images.delete(image_name, force=True)
except aiodocker.exceptions.DockerError:
pass
await docker.images.pull(image_name)
yield event.plain_result("重新拉取沙箱镜像成功。")
@pi.command("file")
async def pi_file(self, event: AstrMessageEvent):
"""在规定秒数(60s)内上传一个文件"""
uid = event.get_sender_id()
self.user_waiting[uid] = time.time()
tip = "文件"
yield event.plain_result(f"代码执行器: 请在 60s 内上传一个{tip}")
await asyncio.sleep(60)
if uid in self.user_waiting:
yield event.plain_result(
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}"
)
self.user_waiting.pop(uid)
@pi.command("clear", alias=["clean"])
async def pi_file_clean(self, event: AstrMessageEvent):
"""清理用户上传的文件"""
uid = event.get_sender_id()
if uid in self.user_waiting:
self.user_waiting.pop(uid)
yield event.plain_result(
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 已清理。"
)
else:
yield event.plain_result(
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有等待上传文件。"
)
@pi.command("list")
async def pi_file_list(self, event: AstrMessageEvent):
"""列出用户上传的文件"""
uid = event.get_sender_id()
if uid in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[uid]
yield event.plain_result(
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 上传的文件: {files}"
)
else:
yield event.plain_result(
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有上传文件。"
)
@llm_tool("python_interpreter")
async def python_interpreter(self, event: AstrMessageEvent):
"""Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
"""
if not await self.is_docker_available():
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
plain_text = event.message_str
# 创建必要的工作目录和幻术码
magic_code = await self.gen_magic_code()
workplace_path = os.path.join(self.workplace_path, magic_code)
output_path = os.path.join(workplace_path, "output")
os.makedirs(workplace_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
files = []
# 文件
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
if not file_path:
continue
elif not os.path.exists(file_path):
logger.warning(f"文件 {file_path} 不存在,已忽略。")
continue
# cp
file_name = os.path.basename(file_path)
shutil.copy(file_path, os.path.join(workplace_path, file_name))
files.append(file_name)
logger.debug(f"user query: {plain_text}, files: {files}")
# 整理额外输入
extra_inputs = ""
if files:
extra_inputs += f"User provided files: {files}\n"
obs = ""
n = 5
for i in range(n):
if i > 0:
logger.info(f"Try {i + 1}/{n}")
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=extra_inputs,
extra_prompt=obs,
)
provider = self.context.get_using_provider()
llm_response = await provider.text_chat(
prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}"
)
logger.debug(
"code interpreter llm gened code:" + llm_response.completion_text
)
# 整理代码并保存
code_clean = await self.tidy_code(llm_response.completion_text)
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
f.write(code_clean)
# 启动容器
docker = aiodocker.Docker()
# 检查有没有image
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
except aiodocker.exceptions.DockerError:
# 拉取镜像
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
await docker.images.pull(image_name)
yield event.plain_result(
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})"
)
self.docker_host_astrbot_abs_path = self.config.get(
"docker_host_astrbot_abs_path", ""
)
if self.docker_host_astrbot_abs_path:
host_shared = os.path.join(
self.docker_host_astrbot_abs_path, self.shared_path
)
host_output = os.path.join(
self.docker_host_astrbot_abs_path, output_path
)
host_workplace = os.path.join(
self.docker_host_astrbot_abs_path, workplace_path
)
else:
host_shared = os.path.abspath(self.shared_path)
host_output = os.path.abspath(output_path)
host_workplace = os.path.abspath(workplace_path)
logger.debug(
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}"
)
container = await docker.containers.run(
{
"Image": image_name,
"Cmd": ["python", "exec.py"],
"Memory": 512 * 1024 * 1024,
"NanoCPUs": 1000000000,
"HostConfig": {
"Binds": [
f"{host_shared}:/astrbot_sandbox/shared:ro",
f"{host_output}:/astrbot_sandbox/output:rw",
f"{host_workplace}:/astrbot_sandbox:rw",
]
},
"Env": [f"MAGIC_CODE={magic_code}"],
"AutoRemove": True,
}
)
logger.debug(f"Container {container.id} created.")
logs = await self.run_container(container)
logger.debug(f"Container {container.id} finished.")
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
match = re.match(pattern, log)
if match:
ok = True
if match.group(1) == "TEXT":
yield event.plain_result(match.group(2))
elif match.group(1) == "IMAGE":
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
# logger.debug(f"Sending file: {file_path}")
# file_s3_url = await self.file_upload(file_path)
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_path)]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
traceback = "\n".join(logs[idx:])
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}"
)
break
else:
# 成功了
self.user_file_msg_buffer.pop(event.get_session_id())
return
yield event.plain_result(
"经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。"
)
@pi.command("cleanfile")
async def pi_cleanfile(self, event: AstrMessageEvent):
"""清理用户上传的文件"""
for file in self.user_file_msg_buffer[event.get_session_id()]:
try:
os.remove(file)
except BaseException as e:
logger.error(f"删除文件 {file} 失败: {e}")
self.user_file_msg_buffer.pop(event.get_session_id())
yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。")
async def run_container(
self, container: aiodocker.docker.DockerContainer, timeout: int = 20
) -> list[str]:
"""Run the container and get the output"""
try:
await container.wait(timeout=timeout)
logs = await container.log(stdout=True, stderr=True)
return logs
except asyncio.TimeoutError:
logger.warning(f"Container {container.id} timeout.")
await container.kill()
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
finally:
await container.delete()

View File

@@ -1,4 +0,0 @@
name: astrbot-python-interpreter
desc: Python 代码执行器
author: Soulter
version: 0.0.1

View File

@@ -1 +0,0 @@
aiodocker

View File

@@ -1,22 +0,0 @@
import os
def _get_magic_code():
"""防止注入攻击"""
return os.getenv("MAGIC_CODE")
def send_text(text: str):
print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}")
def send_image(image_path: str):
if not os.path.exists(image_path):
raise Exception(f"Image file not found: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
def send_file(file_path: str):
if not os.path.exists(file_path):
raise Exception(f"File not found: {file_path}")
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")

View File

@@ -50,6 +50,7 @@ dependencies = [
"wechatpy>=1.8.18",
"audioop-lts ; python_full_version >= '3.13'",
"click>=8.2.1",
"shipyard-python-sdk>=0.2.3",
"pypdf>=6.1.1",
"aiofiles>=25.1.0",
"rank-bm25>=0.2.2",

3540
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff