Compare commits

..

15 Commits

Author SHA1 Message Date
LIghtJUNction
c56d2aeced Docs update 2025-11-03 23:35:22 +08:00
LIghtJUNction
fb56ac9f47 Dependency injection example 2025-11-03 23:28:13 +08:00
LIghtJUNction
5719dbd8b2 Update abc.py 2025-11-03 20:46:32 +08:00
LIghtJUNction
5217450c8a init 2025-11-03 20:28:55 +08:00
LIghtJUNction
15aafa2043 Update astrbot/core/config/default.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-03 14:26:00 +08:00
Copilot
36a3b4d318 chore: run ruff format to fix code formatting (#3291)
* Initial plan

* chore: run ruff format to fix code formatting

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 14:25:01 +08:00
LIghtJUNction
7518c4a057 Update astrbot/core/utils/astrbot_path.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-03 12:36:47 +08:00
LIghtJUNction
8e3bd92c09 Update astrbot/core/utils/astrbot_path.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-03 03:14:37 +08:00
Copilot
9e4fec8488 [WIP] Update path class and deprecate old functions (#3287)
* Initial plan

* feat: add comprehensive tests for AstrbotPaths, reaching 100% coverage

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 03:06:53 +08:00
Copilot
983264dc1a fix: remove unused temporary-directory variables (#3286)
* Initial plan

* fix: remove unused temporary-directory variables (ruff check --fix)

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 02:51:12 +08:00
Copilot
b3f23b23da chore: format code with ruff format (#3283)
* Initial plan

* chore: format code with ruff format

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 02:31:47 +08:00
LIghtJUNction
212d8cc101 Unify version numbers; move constants to the base package (#3281)
Co-authored-by: 赵天乐(tyler zhao) <189870321+tyler-ztl@users.noreply.github.com>
2025-11-03 00:02:21 +08:00
LIghtJUNction
36de3c541a Update astrbot/base/paths.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-11-02 22:57:05 +08:00
LIghtJUNction
1151677f13 fix: the abstract-base-class decorator order was reversed
Co-Authored-By: 赵天乐(tyler zhao) <189870321+tyler-ztl@users.noreply.github.com>
2025-11-02 22:54:43 +08:00
LIghtJUNction
2ccb85d802 Unify the path class; the old functions are deprecated
The .env file stores the unified root path
Static files are bundled with the package instead of being looked up three parent directories up...

Co-Authored-By: 赵天乐(tyler zhao) <189870321+tyler-ztl@users.noreply.github.com>
2025-11-02 22:47:22 +08:00
154 changed files with 3450 additions and 3444 deletions


@@ -1,9 +1,9 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github actions
.git
# github actions
.github/
.*ignore
.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
@@ -15,10 +15,10 @@ env/
venv*/
ENV/
.conda/
README*.md
dashboard/
data/
changelogs/
tests/
.ruff_cache/
.astrbot
astrbot.lock
.astrbot

.env (new file)

@@ -0,0 +1 @@
ASTRBOT_ROOT = ./data

.env.example (new file)

@@ -0,0 +1,2 @@
# ASTRBOT 数据目录
# ASTRBOT_ROOT = ./data


@@ -16,7 +16,7 @@ body:
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?可以从 [此](https://plugins.astrbot.app) 右下角提交。
不熟悉 JSON ?可以从 [此](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
- type: textarea
id: plugin-info


@@ -1,63 +1,63 @@
# AstrBot Development Instructions
# AstrBot 开发指南
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
AstrBot 是一个使用 Python 编写、配备 Vue.js 仪表盘的多平台 LLM 聊天机器人开发框架。它支持多个消息平台(QQ、Telegram、Discord 等)和多种 LLM 提供商(OpenAI、Anthropic、Google Gemini 等)。
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
始终优先参考这些指南,仅在遇到与此处信息不符的意外情况时才回退到搜索或 bash 命令。
## Working Effectively
## 高效工作
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### 引导和安装依赖
- **需要 Python 3.10+** - 检查 `.python-version` 文件
- 安装 UV 包管理器:`pip install uv`
- 安装项目依赖:`uv sync` -- 需要 6-7 分钟。绝不要取消。设置超时时间为 10+ 分钟。
- 创建必需的目录:`mkdir -p data/plugins data/config data/temp`
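The bootstrap checklist above can be sketched as a small pre-flight helper. This is a minimal illustration, not code from the repository: the `preflight` name is made up, and only the `.python-version` file and the three `data/` subdirectories from the list above are assumed.

```python
import sys
from pathlib import Path


def preflight(root: Path) -> list[str]:
    """Illustrative pre-flight check mirroring the bootstrap list above."""
    problems: list[str] = []
    # Python 3.10+ is required
    if sys.version_info < (3, 10):
        problems.append("Python 3.10+ is required")
    # .python-version pins the interpreter version, when present
    pin = root / ".python-version"
    if pin.exists() and not pin.read_text().strip().startswith("3."):
        problems.append("unexpected .python-version content")
    # create the required data directories (mkdir -p equivalent)
    for sub in ("data/plugins", "data/config", "data/temp"):
        (root / sub).mkdir(parents=True, exist_ok=True)
    return problems
```

Running this before `uv sync` would surface interpreter mismatches early and guarantee the `data/` layout exists.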
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### 运行应用程序
- 运行主应用程序:`uv run main.py` -- 约 3 秒启动
- 应用程序在 http://localhost:6185 创建 WebUI(默认凭据:`astrbot`/`astrbot`)
- 应用程序自动从 `packages/` 和 `data/plugins/` 目录加载插件
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### 仪表盘构建(Vue.js/Node.js)
- **前置要求**:需要 Node.js 20+ 和 npm 10+
- 导航到仪表盘:`cd dashboard`
- 安装仪表盘依赖:`npm install` -- 需要 2-3 分钟。绝不要取消。设置超时时间为 5+ 分钟。
- 构建仪表盘:`npm run build` -- 需要 25-30 秒。绝不要取消。
- 仪表盘在 `dashboard/dist/` 创建优化的生产构建
### Testing
- Do not generate test files for now.
### 测试
- 暂时不要生成测试文件。
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### 代码质量和检查
- 安装 ruff 检查器:`uv add --dev ruff`
- 检查代码风格:`uv run ruff check .` -- 耗时 <1 秒
- 检查格式:`uv run ruff format --check .` -- 耗时 <1 秒
- 修复格式:`uv run ruff format .`
- **始终**在提交更改前运行 `uv run ruff check .` 和 `uv run ruff format .`
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### 插件开发
- 插件从 `packages/`(内置)和 `data/plugins/`(用户安装)加载
- 插件系统支持函数工具和消息处理器
- 关键插件:python_interpreter、web_searcher、astrbot、reminder、session_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
### 常见问题和解决方法
- **仪表盘下载失败**:已知的"除以零"错误问题 - 应用程序仍可正常工作
- **测试中的导入错误**:确保使用 `uv run` 在适当的环境中运行测试
- **构建超时**始终设置适当的超时时间uv sync 为 10+ 分钟,npm install 为 5+ 分钟)
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## CI/CD 集成
- GitHub Actions 工作流在 `.github/workflows/`
- 通过 `Dockerfile` 支持 Docker 构建
- Pre-commit 钩子强制执行 ruff 格式化和检查
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Docker 支持
- 主要部署方法:`docker run soulter/astrbot:latest`
- 可用的 Compose 文件:`compose.yml`
- 暴露端口:6185(WebUI)、6195(WeChat)、6199(QQ)等
- 需要挂载卷:`./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
## 多语言支持
- 文档包括中文(README.md)、英文(README_en.md)、日文(README_ja.md)
- UI 支持国际化
- 默认语言为中文
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
请记住:这是一个有真实用户的生产聊天机器人框架。始终进行彻底测试,确保更改不会破坏现有功能。


@@ -12,21 +12,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
bash \
ffmpeg \
curl \
gnupg \
git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y nodejs
RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv \
&& echo "3.11" > .python-version
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]
CMD [ "python", "main.py" ]

Dockerfile_with_node (new file)

@@ -0,0 +1,35 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]

README.md

@@ -8,7 +8,7 @@
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
@@ -119,73 +119,83 @@ uv run main.py
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## 支持的消息平台
## 消息平台支持情况
**官方维护**
- QQ (官方平台 & OneBot)
- Telegram
- 企微应用 & 企微智能机器人
- 微信客服 & 微信公众号
- 飞书
- 钉钉
- Slack
- Discord
- Satori
- Misskey
- Whatsapp (将支持)
- LINE (将支持)
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方平台) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
## 支持的模型服务
## ⚡ 提供商支持情况
**大模型服务**
- OpenAI 及兼容服务
- Anthropic
- Google Gemini
- Moonshot AI
- 智谱 AI
- DeepSeek
- Ollama (本地部署)
- LM Studio (本地部署)
- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小马算力](https://www.tokenpony.cn/3YPyf)
- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**LLMOps 平台**
- Dify
- 阿里云百炼应用
- Coze
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
- OpenAI Whisper
- SenseVoice
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- 阿里云百炼 TTS
- Azure TTS
- Minimax TTS
- 火山引擎 TTS
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献
@@ -219,7 +229,7 @@ pre-commit install
## ⭐ Star History
> [!TIP]
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
<div align="center">

astrbot/__main__.py (new file)

@@ -0,0 +1,97 @@
import argparse
import asyncio
import mimetypes
import os
import sys
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.config.default import VERSION
from astrbot.core.initial_loader import InitialLoader
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot_api import LOGO, IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
def check_env() -> None:
if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
logger.error("请使用 Python3.10+ 运行本项目。")
exit()
# os.makedirs("data/config", exist_ok=True)
# os.makedirs("data/plugins", exist_ok=True)
# os.makedirs("data/temp", exist_ok=True)
# 针对问题 #181 的临时解决方案
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("application/json", ".json")
async def check_dashboard_files(webui_dir: str | None = None):
"""下载管理面板文件"""
# 指定webui目录
if webui_dir:
if os.path.exists(webui_dir):
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
return webui_dir
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
data_dist_path = str(AstrbotPaths.astrbot_root / "dist")
if os.path.exists(data_dist_path):
v = await get_dashboard_version()
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI 版本已是最新。")
else:
logger.warning(
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
)
return data_dist_path
logger.info(
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。",
)
try:
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return data_dist_path
def main():
parser = argparse.ArgumentParser(description="AstrBot")
parser.add_argument(
"--webui-dir",
type=str,
help="指定 WebUI 静态文件目录路径",
default=None,
)
args = parser.parse_args()
check_env()
# 启动日志代理
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# 检查仪表板文件
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
db = db_helper
# 打印 logo
logger.info(LOGO)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
asyncio.run(core_lifecycle.start())
if __name__ == "__main__":
main()


@@ -36,8 +36,7 @@ from astrbot.core.star.config import *
# provider
from astrbot.core.provider import Provider, ProviderMetaData
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, Personality, ProviderMetaData
# platform
from astrbot.core.platform import (


@@ -1,5 +1,4 @@
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, STTProvider
from astrbot.core.provider import Personality, Provider, STTProvider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,


@@ -1 +1 @@
__version__ = "3.5.23"
"""AstrBot CLI入口"""


@@ -1,27 +1,22 @@
"""AstrBot CLI入口"""
import sys
from importlib.metadata import version
import click
from . import __version__
from astrbot_api import LOGO
from .commands import conf, init, plug, run
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
__version__ = version("astrbot")
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo(LOGO)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")
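The CLI diff above replaces a hardcoded `__version__` with `version("astrbot")` from `importlib.metadata`, so the version is read from the installed distribution's metadata. A hedged sketch of that pattern; the fallback behavior is an assumption for illustration, not part of the diff:

```python
from importlib.metadata import PackageNotFoundError, version


def get_version(dist_name: str, fallback: str = "0.0.0+unknown") -> str:
    """Resolve the installed distribution's version, with a fallback."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # metadata is unavailable, e.g. when running from a raw source checkout
        return fallback
```

The upside of this approach is a single source of truth (the package metadata from `pyproject.toml`); the trade-off is that it raises `PackageNotFoundError` when the package is not installed, which a fallback like the one above can absorb.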


@@ -3,6 +3,11 @@ import shutil
from pathlib import Path
import click
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from ..utils import (
PluginStatus,
@@ -47,8 +52,7 @@ def display_plugins(plugins, title=None, color=None):
@click.argument("name")
def new(name: str):
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
plug_path = AstrbotPaths.getPaths(name).plugins
if plug_path.exists():
raise click.ClickException(f"插件 {name} 已存在")
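Several diffs above resolve `AstrbotPaths` through `sync_base_container.get(type[IAstrbotPaths])`, i.e. a dependency-injection container keyed by an interface type. The real `astrbot_sdk` container is not shown in this compare view; the following is only a minimal sketch of the idea, with all names invented:

```python
class Container:
    """Minimal type-keyed registry sketching the `sync_base_container.get` idea."""

    def __init__(self) -> None:
        self._registry: dict = {}

    def register(self, key, value) -> None:
        self._registry[key] = value

    def get(self, key):
        try:
            return self._registry[key]
        except KeyError:
            raise LookupError(f"no binding registered for {key!r}") from None


# usage: bind a concrete class to an interface-shaped key
class IPaths: ...
class Paths(IPaths): ...

container = Container()
container.register(type[IPaths], Paths)
resolved = container.get(type[IPaths])  # the bound Paths class
```

Keying on `type[IPaths]` works because `type[...]` aliases compare equal by their arguments, so the lookup site never needs to import the concrete implementation.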


@@ -1,22 +1,31 @@
import warnings
from pathlib import Path
import click
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths = sync_base_container.get(type[IAstrbotPaths])
def check_astrbot_root(path: str | Path) -> bool:
"""检查路径是否为 AstrBot 根目录"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
return False
if not (path / ".astrbot").exists():
return False
return True
warnings.warn(
"请使用 AstrbotPaths 类代替本模块中的函数",
DeprecationWarning,
stacklevel=2,
)
return AstrbotPaths.is_root(Path(path))
def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径"""
return Path.cwd()
warnings.warn(
"请使用 AstrbotPaths 类代替本模块中的函数",
DeprecationWarning,
stacklevel=2,
)
return AstrbotPaths.astrbot_root
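The rewritten helpers above keep the old module-level API but emit a `DeprecationWarning` before delegating to `AstrbotPaths`. The shim pattern, sketched generically (the function name and message here are illustrative, not from the diff):

```python
import warnings
from pathlib import Path


def legacy_get_root() -> Path:
    """Deprecated shim: warn, then delegate to the replacement API."""
    warnings.warn(
        "use the paths class instead of this module-level function",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not this shim
    )
    return Path.cwd()
```

`stacklevel=2` is the key detail: it makes the warning point at the caller's line, so downstream users see where *their* code needs updating.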
async def check_dashboard(astrbot_root: Path) -> None:


@@ -9,10 +9,6 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from .log import LogBroker, LogManager # noqa
from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", False)


@@ -1,21 +1,16 @@
from dataclasses import dataclass
from typing import Any, Generic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypeVar
from .message import Message
TContext = TypeVar("TContext", default=Any)
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds
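The diff above moves `ContextWrapper` from a pydantic dataclass to the stdlib one, so pydantic's `Field(default_factory=list)` becomes `dataclasses.field(default_factory=list)`. A simplified, self-contained sketch of the resulting shape (class and field names trimmed for illustration):

```python
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar

TContext = TypeVar("TContext")


@dataclass
class Wrapper(Generic[TContext]):
    """Simplified stand-in for ContextWrapper, stdlib dataclasses only."""
    context: TContext
    # field(default_factory=list) gives each instance its own fresh list,
    # replacing pydantic's Field(default_factory=list)
    messages: list[Any] = field(default_factory=list)
    tool_call_timeout: int = 60
```

The stdlib version drops runtime validation but removes the pydantic dependency and the `arbitrary_types_allowed` config the old declaration needed.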


@@ -40,13 +40,6 @@ class BaseAgentRunner(T.Generic[TContext]):
"""Process a single step of the agent."""
...
@abc.abstractmethod
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
...
@abc.abstractmethod
def done(self) -> bool:
"""Check if the agent has completed its task.


@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
@@ -55,20 +55,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.agent_hooks = agent_hooks
self.run_context = run_context
messages = []
# append existing messages in the run context
for msg in request.contexts:
messages.append(Message.model_validate(msg))
if request.prompt is not None:
m = await request.assemble_context()
messages.append(Message.model_validate(m))
if request.system_prompt:
messages.insert(
0,
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
@@ -110,22 +96,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
elif llm_response.completion_text:
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text),
),
)
elif llm_response.reasoning_content:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain(type="reasoning").message(
llm_response.reasoning_content,
),
),
)
continue
llm_resp_result = llm_response
break # got final response
@@ -153,13 +130,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# record the final assistant message
self.run_context.messages.append(
Message(
role="assistant",
content=llm_resp.completion_text or "",
),
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
@@ -186,16 +156,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
result.type = "tool_call_result"
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
@@ -208,23 +175,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
tool_calls_result=tool_call_result_blocks,
)
# record the assistant message with tool calls
self.run_context.messages.extend(
tool_calls_result.to_openai_messages_model()
)
self.req.append_tool_calls_result(tool_calls_result)
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
step_count = 0
while not self.done() and step_count < max_step:
step_count += 1
async for resp in self.step():
yield resp
async def _handle_function_tools(
self,
req: ProviderRequest,


@@ -4,13 +4,12 @@ from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated
from pydantic import Field, model_validator
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult
@dataclass
@@ -56,14 +55,23 @@ class FunctionTool(ToolSchema, Generic[TContext]):
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
def __dict__(self) -> dict[str, Any]:
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
}
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> str | mcp.types.CallToolResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
@dataclass
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -71,7 +79,8 @@ class ToolSet:
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
tools: list[FunctionTool] = Field(default_factory=list)
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""


@@ -1,19 +1,14 @@
from pydantic import Field
from pydantic.dataclasses import dataclass
from dataclasses import dataclass
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.context import Context
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class AstrAgentContext:
context: Context
"""The star context instance"""
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
event: AstrMessageEvent
"""The message event associated with the agent context."""
extra: dict[str, str] = Field(default_factory=dict)
"""Customized extra data."""
AgentContextWrapper = ContextWrapper[AstrAgentContext]


@@ -1,36 +0,0 @@
from typing import Any
from mcp.types import CallToolResult
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
pass
MAIN_AGENT_HOOKS = MainAgentHooks()


@@ -1,80 +0,0 @@
import traceback
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use:
await astr_event.send(resp.data["chain"])
continue
if stream_to_general and resp.type == "streaming_delta":
continue
if stream_to_general or not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
chain = resp.data["chain"]
if chain.type == "reasoning" and not show_reasoning:
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return


@@ -1,246 +0,0 @@
import asyncio
import inspect
import traceback
import typing as T
import mcp
from astrbot import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
MessageEventResult,
)
from astrbot.core.provider.register import llm_tools
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input")
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
ctx = run_context.context.context
event = run_context.context.event
umo = event.unified_msg_origin
prov_id = await ctx.get_current_chat_provider_id(umo)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
prompt=input_,
system_prompt=tool.agent.instructions,
tools=toolset,
max_steps=30,
run_hooks=tool.agent.run_hooks,
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
)
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timed out after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# Step through the async generator; for each ret it yields, run the code below.
# The yielded value must be a MessageEventResult or None (no return value).
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
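The deleted executor above detects whether a tool subclass overrides `call` by walking the MRO and comparing `__dict__` entries, which ignores merely inherited methods. A standalone sketch of that check (the class names here are stand-ins, not the real AstrBot types):

```python
class FunctionTool:
    """Stand-in base class; the real one lives in astrbot.core.agent.tool."""
    def call(self, ctx, **kwargs):
        return "base"

class PlainTool(FunctionTool):
    pass

class CustomTool(FunctionTool):
    def call(self, ctx, **kwargs):
        return "custom"

def overrides_call(tool):
    # A class that redefines `call` carries its own __dict__ entry, which is a
    # different function object from FunctionTool.call; inherited methods do
    # not appear in the subclass __dict__ at all.
    for ty in type(tool).mro():
        if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
            return True
    return False

print(overrides_call(PlainTool()))   # False
print(overrides_call(CustomTool()))  # True
```

Comparing against `FunctionTool.call` by identity, rather than by name, avoids flagging classes that simply inherit the base implementation.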

View File

@@ -3,11 +3,15 @@ import json
import logging
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
ASTRBOT_CONFIG_PATH = str(AstrbotPaths.astrbot_root / "cmd_config.json")
logger = logging.getLogger("astrbot")
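The hunk above swaps `os.path.join(get_astrbot_data_path(), ...)` for the injected `AstrbotPaths` object's pathlib root. A minimal sketch of the equivalence; the root value is illustrative, and the real class is resolved from `sync_base_container` at import time:

```python
import posixpath
from pathlib import PurePosixPath

# Stand-in for AstrbotPaths.astrbot_root (illustrative value).
root = PurePosixPath("/opt/astrbot/data")

old_style = posixpath.join(str(root), "cmd_config.json")  # os.path.join idiom
new_style = str(root / "cmd_config.json")                 # pathlib idiom

print(old_style == new_style)  # True
print(new_style)               # /opt/astrbot/data/cmd_config.json
```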

View File

@@ -1,11 +1,17 @@
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
import os
from importlib.metadata import version
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
VERSION = "4.5.7"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
# 警告,请使用version函数获取版本,此变量兼容保留
VERSION = version("astrbot")
DB_PATH = str(AstrbotPaths.astrbot_root / "data_v4.db")
# 默认配置
DEFAULT_CONFIG = {
@@ -68,7 +74,7 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"unsupported_streaming_strategy": "realtime_segmenting",
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
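The defaults above replace the string option `unsupported_streaming_strategy` (values `realtime_segmenting` / `turn_off`) with the boolean `streaming_segmented`. A hypothetical migration helper for existing `cmd_config.json` entries might look like this; the mapping `realtime_segmenting -> True` is an assumption, and the project may migrate old configs differently:

```python
def migrate_streaming_config(settings):
    """Map the removed string option onto the new boolean flag.

    Hypothetical helper -- not part of the diff above.
    """
    migrated = dict(settings)
    strategy = migrated.pop("unsupported_streaming_strategy", None)
    if strategy is not None and "streaming_segmented" not in migrated:
        migrated["streaming_segmented"] = strategy == "realtime_segmenting"
    return migrated

print(migrate_streaming_config({"unsupported_streaming_strategy": "realtime_segmenting"}))
# {'streaming_segmented': True}
```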
@@ -740,7 +746,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
@@ -756,7 +761,6 @@ CONFIG_METADATA_2 = {
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -770,7 +774,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"xai_native_search": False,
"modalities": ["text", "image", "tool_use"],
@@ -802,7 +805,6 @@ CONFIG_METADATA_2 = {
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -817,7 +819,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "llama-3.1-8b",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -834,7 +835,6 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -876,24 +876,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"Groq": {
"id": "groq_default",
"provider": "groq",
"type": "groq_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.groq.com/openai/v1",
"timeout": 120,
"model_config": {
"model": "openai/gpt-oss-20b",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
@@ -907,7 +889,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -924,7 +905,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -941,7 +921,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
},
"小马算力": {
@@ -957,7 +936,6 @@ CONFIG_METADATA_2 = {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_headers": {},
"custom_extra_body": {},
},
"优云智算": {
@@ -972,7 +950,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -986,7 +963,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -1002,8 +978,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "glm-4-flash",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Dify": {
@@ -1060,7 +1034,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -1073,7 +1046,6 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60,
"custom_headers": {},
"custom_extra_body": {},
},
"Whisper(API)": {
@@ -1355,12 +1327,6 @@ CONFIG_METADATA_2 = {
"render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
},
"custom_headers": {
"description": "自定义添加请求头",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
},
"custom_extra_body": {
"description": "自定义请求体参数",
"type": "dict",
@@ -2010,8 +1976,8 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
"streaming_segmented": {
"type": "bool",
},
"max_agent_step": {
"description": "工具调用轮数上限",
@@ -2316,15 +2282,9 @@ CONFIG_METADATA_3 = {
"description": "流式回复",
"type": "bool",
},
"provider_settings.unsupported_streaming_strategy": {
"description": "不支持流式回复的平台",
"type": "string",
"options": ["realtime_segmenting", "turn_off"],
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
"labels": ["实时分段回复", "关闭流式回复"],
"condition": {
"provider_settings.streaming_response": True,
},
"provider_settings.streaming_segmented": {
"description": "不支持流式回复的平台采取分段输出",
"type": "bool",
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",

View File

@@ -1,9 +0,0 @@
from __future__ import annotations
class AstrBotError(Exception):
"""Base exception for all AstrBot errors."""
class ProviderNotFoundError(AstrBotError):
"""Raised when a specified provider is not found."""

View File

@@ -21,20 +21,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import asyncio
import base64
import json
import os
import uuid
from enum import Enum
from astrbot_api.abc import IAstrbotPaths
from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
class ComponentType(str, Enum):
# Basic Segment Types
Plain = "Plain" # plain text message
@@ -153,8 +153,7 @@ class Record(BaseMessageComponent):
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
file_path = str(AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.jpg")
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
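The `Record` branch above decodes a `base64://` payload into a uniquely named temp file. A self-contained sketch of the same pattern (directory handling and names are illustrative):

```python
import base64
import tempfile
import uuid
from pathlib import Path

def save_base64_payload(file_field, temp_dir):
    """Decode a `base64://` payload into a uniquely named temp file."""
    raw = base64.b64decode(file_field.removeprefix("base64://"))
    path = Path(temp_dir) / f"{uuid.uuid4()}.jpg"
    path.write_bytes(raw)
    return str(path.resolve())

with tempfile.TemporaryDirectory() as d:
    payload = "base64://" + base64.b64encode(b"\xff\xd8\xff").decode()
    saved = save_base64_payload(payload, d)
    print(Path(saved).read_bytes())  # b'\xff\xd8\xff'
```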
@@ -242,8 +241,9 @@ class Video(BaseMessageComponent):
if url and url.startswith("file:///"):
return url[8:]
if url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
video_file_path = str(
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4().hex}"
)
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
@@ -442,8 +442,9 @@ class Image(BaseMessageComponent):
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
image_file_path = str(
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.jpg"
)
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
@@ -527,7 +528,7 @@ class Reply(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
type = ComponentType.Poke
id: int | None = 0
qq: int | None = 0
@@ -654,33 +655,19 @@ class File(BaseMessageComponent):
@property
def file(self) -> str:
"""获取文件路径如果文件不存在但有URL则同步下载文件
"""获取本地文件路径(仅返回已存在的文件
⚠️ 警告:此属性不会自动下载文件!
- 如果文件已存在,返回绝对路径
- 如果只有 URL 没有本地文件,返回空字符串
- 需要下载文件请使用 `await get_file()` 方法
Returns:
str: 文件路径
str: 文件的绝对路径,如果文件不存在则返回空字符串
"""
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@file.setter
@@ -714,15 +701,18 @@ class File(BaseMessageComponent):
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
if self.file_:
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
if not self.url:
raise ValueError("No URL provided for download")
download_dir = str(AstrbotPaths.astrbot_root / "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
file_path = str(AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4().hex}")
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
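The revised `File` component splits access into a non-downloading `.file` property and an async `get_file()`, removing the old `run_until_complete` hazard. A minimal sketch of that contract with the network download stubbed out (all names here are illustrative, not the real class):

```python
import asyncio
import os
import tempfile

class FileSeg:
    """Illustrative stand-in for the File component's accessor contract."""

    def __init__(self, path=None, url=None):
        self.file_ = path
        self.url = url

    @property
    def file(self):
        # Never downloads: return the local path only if it already exists.
        if self.file_ and os.path.exists(self.file_):
            return os.path.abspath(self.file_)
        return ""

    async def get_file(self):
        # Async path: download on demand, then return the absolute path.
        if self.file_ and os.path.exists(self.file_):
            return os.path.abspath(self.file_)
        if self.url:
            await self._download()
            return os.path.abspath(self.file_)
        return ""

    async def _download(self):
        # Stand-in for download_file(self.url, ...): create a stub temp file.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        self.file_ = path

seg = FileSeg(url="https://example.com/report.pdf")
print(repr(seg.file))        # '' -- the property never triggers a download
path = asyncio.run(seg.get_file())
print(os.path.exists(path))  # True
os.remove(path)
```

Keeping the property side-effect-free means it is always safe to read inside a running event loop; only `get_file()` may block on I/O.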

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
@dataclass
@@ -15,3 +15,4 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
call_local_llm_tool = call_local_llm_tool

View File

@@ -3,6 +3,8 @@ import traceback
import typing as T
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
@@ -105,3 +107,66 @@ async def call_event_hook(
return True
return event.is_stopped()
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# Step through the async generator; for each ret it yields, run the code below.
# The yielded value must be a MessageEventResult or None (no return value).
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
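`call_local_llm_tool` above normalizes two handler shapes, async generators and plain coroutines, into one async-generator interface using `inspect`. A reduced sketch of that dispatch:

```python
import asyncio
import inspect

async def drive(handler_result):
    """Collect a coroutine's result or an async generator's yields into a list,
    echoing the dispatch in call_local_llm_tool above (reduced sketch)."""
    out = []
    if inspect.isasyncgen(handler_result):
        async for ret in handler_result:
            out.append(ret)
        if not out:
            out.append(None)  # generator finished without yielding
    elif inspect.iscoroutine(handler_result):
        out.append(await handler_result)
    return out

async def gen_tool():
    yield "step 1"
    yield "step 2"

async def coro_tool():
    return "done"

print(asyncio.run(drive(gen_tool())))   # ['step 1', 'step 2']
print(asyncio.run(drive(coro_tool())))  # ['done']
```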

View File

@@ -3,10 +3,20 @@
import asyncio
import copy
import json
import traceback
from collections.abc import AsyncGenerator
from typing import Any
from mcp.types import CallToolResult
from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import Image
@@ -21,19 +31,324 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
from ....astr_agent_context import AgentContextWrapper
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....astr_agent_run_util import AgentRunner, run_agent
from ....astr_agent_tool_exec import FunctionToolExecutor
from ...context import PipelineContext, call_event_hook
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
from ..stage import Stage
from ..utils import inject_kb_context
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
event=run_context.context.event,
)
event = run_context.context.event
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=run_context.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
)
result = (
f"Agent {tool.agent.name} responded with: {llm_response.completion_text}\n\n"
"Note: If the result is an error or the agent needs more information, ask the user for it first, then hand the details back to the agent."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
logger.debug(f"Found call in: {ty}")
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
@@ -48,15 +363,11 @@ class LLMRequestSubStage(Stage):
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_reasoning = settings.get("display_reasoning_text", False)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -95,12 +406,63 @@ class LLMRequestSubStage(Stage):
raise RuntimeError("无法创建新的对话。")
return conversation
async def _apply_kb_context(
async def process(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""应用知识库上下文到请求中"""
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin,
@@ -110,40 +472,43 @@ class LLMRequestSubStage(Stage):
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if len(contexts) // 2 <= self.max_context_length:
return contexts
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(req.contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
req.contexts = req.contexts[index:]
return truncated_contexts
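The truncation helper above keeps the tail of the history and then realigns it on the first user turn, so the context always starts with a user message. A standalone sketch (parameter names are illustrative):

```python
def truncate_contexts(contexts, max_pairs, dequeue=1):
    """Keep roughly the last `max_pairs` user/assistant pairs, then realign
    the head on a user turn. Mirrors the truncation above."""
    if max_pairs == -1 or len(contexts) // 2 <= max_pairs:
        return contexts
    kept = contexts[-(max_pairs - dequeue + 1) * 2:]
    # Drop leading assistant turns so history starts with a user message.
    first_user = next(
        (i for i, m in enumerate(kept) if m.get("role") == "user"), None
    )
    if first_user:  # None or 0 both mean "no trim needed"
        kept = kept[first_user:]
    return kept

history = [
    {"role": "user", "content": "q1"}, {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "q2"}, {"role": "assistant", "content": "a2"},
    {"role": "user", "content": "q3"}, {"role": "assistant", "content": "a3"},
]
print([m["content"] for m in truncate_contexts(history, max_pairs=2)])
# ['q2', 'a2', 'q3', 'a3']
```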
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
def _modalities_fix(
self,
provider: Provider,
req: ProviderRequest,
):
"""检查提供商的模态能力,清理请求中的不支持内容"""
# fix messages
req.contexts = self.fix_messages(req.contexts)
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
@@ -157,13 +522,7 @@ class LLMRequestSubStage(Stage):
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
)
req.func_tool = None
def _plugin_tool_fix(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""根据事件中的插件设置,过滤请求中的工具列表"""
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
@@ -177,6 +536,80 @@ class LLMRequestSubStage(Stage):
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
)
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use),
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
# Restore the backed-up contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# Handle the WebChat special case asynchronously
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)
async def _handle_webchat(
self,
event: AstrMessageEvent,
@@ -224,6 +657,9 @@ class LLMRequestSubStage(Stage):
),
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
)
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
@@ -251,9 +687,6 @@ class LLMRequestSubStage(Stage):
logger.debug("LLM 响应为空,不保存记录。")
return
if req.contexts is None:
req.contexts = []
# History context
messages = copy.deepcopy(req.contexts)
# The user input for this round of the conversation
@@ -273,7 +706,7 @@ class LLMRequestSubStage(Stage):
history=messages,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
@@ -288,184 +721,3 @@ class LLMRequestSubStage(Stage):
else:
fixed_messages.append(message)
return fixed_messages
async def process(
self,
event: AstrMessageEvent,
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# Check the session-level LLM on/off state
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
logger.debug("ready to request llm provider")
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix and not event.message_str.startswith(
self.provider_wake_prefix
):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection has been moved into the packages/astrbot plugin.
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# apply knowledge base context
await self._apply_kb_context(event, req)
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# check provider modalities, if provider does not support image/tool_use, clear them in request.
self._modalities_fix(provider, req)
# filter tools, only keep tools from this pipeline's selected plugins
self._plugin_tool_fix(event, req)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# Back up req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# Streaming response
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
show_reasoning=self.show_reasoning,
),
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain()
.message(final_llm_resp.completion_text)
.chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
stream_to_general,
show_reasoning=self.show_reasoning,
):
yield
# Restore the backed-up contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# Handle the WebChat special case asynchronously
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)


@@ -1,7 +1,6 @@
import asyncio
import math
import random
from collections.abc import AsyncGenerator
import astrbot.core.message.components as Comp
from astrbot.core import logger
@@ -10,6 +9,7 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
from ..context import PipelineContext, call_event_hook
from ..stage import Stage, register_stage
@@ -152,7 +152,7 @@ class RespondStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
) -> None:
result = event.get_result()
if result is None:
return
@@ -168,15 +168,12 @@ class RespondStage(Stage):
logger.warning("async_stream 为空,跳过发送。")
return
# Streaming results are handed directly to the platform adapter
realtime_segmenting = (
self.config.get("provider_settings", {}).get(
"unsupported_streaming_strategy",
"realtime_segmenting",
)
== "realtime_segmenting"
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented",
False,
)
logger.info(f"应用流式输出({event.get_platform_id()})")
await event.send_streaming(result.async_stream, realtime_segmenting)
await event.send_streaming(result.async_stream, use_fallback)
return
if len(result.chain) > 0:
# Check path mapping
@@ -220,20 +217,21 @@ class RespondStage(Stage):
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
)
return
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
else:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}

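The respond stage above serializes message delivery per session via `session_lock_manager.acquire_lock(...)`. A hedged sketch of what such a session-keyed async lock manager could look like (the real `session_lock_manager` implementation may differ):

```python
import asyncio
from contextlib import asynccontextmanager


class SessionLockManager:
    """One asyncio.Lock per session id, created lazily."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}

    @asynccontextmanager
    async def acquire_lock(self, session_id: str):
        lock = self._locks.setdefault(session_id, asyncio.Lock())
        async with lock:
            yield


async def demo() -> list[int]:
    mgr = SessionLockManager()
    order: list[int] = []

    async def task(n: int) -> None:
        async with mgr.acquire_lock("sess-1"):
            order.append(n)
            await asyncio.sleep(0.01)  # simulate a slow platform send
            order.append(n)

    await asyncio.gather(task(1), task(2))
    return order


result = asyncio.run(demo())
print(result)  # each task's sends stay contiguous: [1, 1, 2, 2]
```

Without the lock, interleaved sends from concurrent handlers could mix components from different replies in one session.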

@@ -16,6 +16,3 @@ class PlatformMetadata:
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str | None = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
support_streaming_message: bool = True
"""平台是否支持真实流式传输"""


@@ -14,7 +14,6 @@ def register_platform_adapter(
default_config_tmpl: dict | None = None,
adapter_display_name: str | None = None,
logo_path: str | None = None,
support_streaming_message: bool = True,
):
"""用于注册平台适配器的带参装饰器。
@@ -43,7 +42,6 @@ def register_platform_adapter(
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
support_streaming_message=support_streaming_message,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls


@@ -29,7 +29,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
@register_platform_adapter(
"aiocqhttp",
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
support_streaming_message=False,
)
class AiocqhttpAdapter(Platform):
def __init__(
@@ -50,7 +49,6 @@ class AiocqhttpAdapter(Platform):
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
support_streaming_message=False,
)
self.bot = CQHttp(
@@ -109,7 +107,7 @@ class AiocqhttpAdapter(Platform):
)
await super().send_by_session(session, message_chain)
async def convert_message(self, event: Event) -> AstrBotMessage | None:
async def convert_message(self, event: Event) -> AstrBotMessage:
logger.debug(f"[aiocqhttp] RawMessage {event}")
if event["post_type"] == "message":
@@ -224,7 +222,7 @@ class AiocqhttpAdapter(Platform):
err = f"aiocqhttp: 无法识别的消息类型: {event.message!s},此条消息将被忽略。如果您在使用 go-cqhttp请将其配置文件中的 message.post-format 更改为 array。"
logger.critical(err)
try:
await self.bot.send(event, err)
self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return None


@@ -37,9 +37,7 @@ class MyEventHandler(dingtalk_stream.EventHandler):
return AckMessage.STATUS_OK, "OK"
@register_platform_adapter(
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
)
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
class DingtalkPlatformAdapter(Platform):
def __init__(
self,
@@ -76,14 +74,6 @@ class DingtalkPlatformAdapter(Platform):
)
self.client_ = client # 用于 websockets 的 client
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
if not dingtalk_id:
return dingtalk_id
prefix = "$:LWCP_v1:$"
if dingtalk_id.startswith(prefix):
return dingtalk_id[len(prefix) :]
return dingtalk_id
async def send_by_session(
self,
session: MessageSesion,
@@ -96,7 +86,6 @@ class DingtalkPlatformAdapter(Platform):
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
support_streaming_message=False,
)
async def convert_msg(
@@ -113,10 +102,10 @@ class DingtalkPlatformAdapter(Platform):
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
user_id=self._id_to_sid(message.sender_id),
user_id=message.sender_id,
nickname=message.sender_nick,
)
abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.self_id = message.chatbot_user_id
abm.message_id = message.message_id
abm.raw_message = message
@@ -124,8 +113,8 @@ class DingtalkPlatformAdapter(Platform):
# Handle all @-mentioned users (including the bot itself, since at_users already contains it)
if message.at_users:
for user in message.at_users:
if id := self._id_to_sid(user.dingtalk_id):
abm.message.append(At(qq=id))
if user.dingtalk_id:
abm.message.append(At(qq=user.dingtalk_id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id


@@ -34,9 +34,7 @@ else:
# Register the platform adapter
@register_platform_adapter(
"discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
)
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
class DiscordPlatformAdapter(Platform):
def __init__(
self,
@@ -92,7 +90,7 @@ class DiscordPlatformAdapter(Platform):
)
message_obj.self_id = self.client_self_id
message_obj.session_id = session.session_id
message_obj.message = message_chain.chain
message_obj.message = message_chain
# Create a temporary event object to send the message
temp_event = DiscordPlatformEvent(
@@ -113,7 +111,6 @@ class DiscordPlatformAdapter(Platform):
"Discord 适配器",
id=self.config.get("id"),
default_config_tmpl=self.config,
support_streaming_message=False,
)
@override


@@ -1,7 +1,6 @@
import asyncio
import base64
import binascii
from collections.abc import AsyncGenerator
import sys
from io import BytesIO
from pathlib import Path
@@ -21,6 +20,11 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata
from .client import DiscordBotClient
from .components import DiscordEmbed, DiscordView
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Custom Discord view component, compatible with older versions
class DiscordViewComponent(BaseMessageComponent):
@@ -44,6 +48,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.client = client
self.interaction_followup_webhook = interaction_followup_webhook
@override
async def send(self, message: MessageChain):
"""发送消息到Discord平台"""
# 解析消息链为 Discord 所需的对象
@@ -92,21 +97,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None:
"""获取当前事件对应的频道对象"""
try:
@@ -193,7 +183,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
BytesIO(img_bytes),
filename=filename or "image.png",
)
except (ValueError, TypeError, binascii.Error):
except (ValueError, TypeError, base64.binascii.Error):
logger.debug(
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}",
)

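The deleted Discord `send_streaming` override drained the generator into one buffer before a single `send`. The same buffering pattern, sketched with plain lists standing in for `MessageChain` (illustrative only):

```python
import asyncio
from collections.abc import AsyncGenerator


async def collect_chains(
    generator: AsyncGenerator[list[str], None],
) -> list[str]:
    """Drain the stream, concatenating all chunks into one buffer."""
    buffer: list[str] | None = None
    async for chain in generator:
        if buffer is None:
            buffer = chain
        else:
            buffer.extend(chain)
    return buffer or []


async def fake_stream() -> AsyncGenerator[list[str], None]:
    for chunk in (["Hel"], ["lo, "], ["world"]):
        yield chunk


merged = asyncio.run(collect_chains(fake_stream()))
print("".join(merged))  # Hello, world
```

This trades real-time delivery for a single consolidated message — useful on platforms with no streaming-edit API.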

@@ -23,9 +23,7 @@ from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent
@register_platform_adapter(
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
)
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
class LarkPlatformAdapter(Platform):
def __init__(
self,
@@ -117,7 +115,6 @@ class LarkPlatformAdapter(Platform):
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):


@@ -45,9 +45,7 @@ MAX_FILE_UPLOAD_COUNT = 16
DEFAULT_UPLOAD_CONCURRENCY = 3
@register_platform_adapter(
"misskey", "Misskey 平台适配器", support_streaming_message=False
)
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self,
@@ -122,7 +120,6 @@ class MisskeyPlatformAdapter(Platform):
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
support_streaming_message=False,
)
async def run(self):


@@ -29,7 +29,8 @@ from astrbot.core.platform.astr_message_event import MessageSession
@register_platform_adapter(
"satori", "Satori 协议适配器", support_streaming_message=False
"satori",
"Satori 协议适配器",
)
class SatoriPlatformAdapter(Platform):
def __init__(
@@ -59,7 +60,6 @@ class SatoriPlatformAdapter(Platform):
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
support_streaming_message=False,
)
self.ws: ClientConnection | None = None


@@ -30,7 +30,6 @@ from .slack_event import SlackMessageEvent
@register_platform_adapter(
"slack",
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
support_streaming_message=False,
)
class SlackAdapter(Platform):
def __init__(
@@ -69,7 +68,6 @@ class SlackAdapter(Platform):
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
support_streaming_message=False,
)
# Initialize the Slack Web Client
@@ -84,7 +82,7 @@ class SlackAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
):
blocks, text = await SlackMessageEvent._parse_slack_blocks(
blocks, text = SlackMessageEvent._parse_slack_blocks(
message_chain=message_chain,
web_client=self.web_client,
)


@@ -5,6 +5,8 @@ import uuid
from collections.abc import Awaitable, Callable
from typing import Any
from astrbot_api.abc import IAstrbotPaths
from astrbot import logger
from astrbot.core.message.components import Image, Plain, Record
from astrbot.core.message.message_event_result import MessageChain
@@ -16,12 +18,13 @@ from astrbot.core.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_sdk import sync_base_container
from ...register import register_platform_adapter
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
class QueueListener:
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
@@ -79,7 +82,7 @@ class WebChatAdapter(Platform):
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
self.imgs_dir = str(AstrbotPaths.astrbot_root / "webchat" / "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata(
@@ -163,9 +166,6 @@ class WebChatAdapter(Platform):
_, _, payload = message.raw_message # type: ignore
message_event.set_extra("selected_provider", payload.get("selected_provider"))
message_event.set_extra("selected_model", payload.get("selected_model"))
message_event.set_extra(
"enable_streaming", payload.get("enable_streaming", True)
)
self.commit_event(message_event)


@@ -109,7 +109,6 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
reasoning_content = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
@@ -125,22 +124,16 @@ class WebChatMessageEvent(AstrMessageEvent):
)
final_data = ""
continue
r = await WebChatMessageEvent._send(
final_data += await WebChatMessageEvent._send(
chain,
session_id=self.session_id,
streaming=True,
)
if chain.type == "reasoning":
reasoning_content += chain.get_plain_text()
else:
final_data += r
await web_chat_back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"reasoning": reasoning_content,
"streaming": True,
"cid": cid,
},


@@ -8,10 +8,14 @@ import traceback
import aiohttp
import anyio
import websockets
from astrbot_api.abc import IAstrbotPaths
from astrbot import logger
from astrbot.api.message_components import At, Image, Plain, Record
from astrbot.api.platform import Platform, PlatformMetadata
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.astrbot_message import (
@@ -19,7 +23,6 @@ from astrbot.core.platform.astrbot_message import (
MessageMember,
MessageType,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
from .wechatpadpro_message_event import WeChatPadProMessageEvent
@@ -32,9 +35,7 @@ except ImportError as e:
)
@register_platform_adapter(
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
)
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
class WeChatPadProAdapter(Platform):
def __init__(
self,
@@ -53,7 +54,6 @@ class WeChatPadProAdapter(Platform):
name="wechatpadpro",
description="WeChatPadPro 消息平台适配器",
id=self.config.get("id", "wechatpadpro"),
support_streaming_message=False,
)
# Save configuration
@@ -71,9 +71,8 @@ class WeChatPadProAdapter(Platform):
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # holds the generated authorization key
self.wxid = None # holds the wxid after a successful login
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
self.credentials_file = str(
AstrbotPaths.astrbot_root / "wechatpadpro_credentials.json"
) # persistence file path
self.ws_handle_task = None
@@ -158,8 +157,8 @@ class WeChatPadProAdapter(Platform):
}
try:
# Ensure the data directory exists
data_dir = os.path.dirname(self.credentials_file)
os.makedirs(data_dir, exist_ok=True)
config_dir = AstrbotPaths.astrbot_root / "config"
config_dir.mkdir(parents=True, exist_ok=True)
with open(self.credentials_file, "w") as f:
json.dump(credentials, f)
except Exception as e:
@@ -790,10 +789,10 @@ class WeChatPadProAdapter(Platform):
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir,
f"wechatpadpro_voice_{abm.message_id}.silk",
file_path = str(
AstrbotPaths.astrbot_root
/ "temp"
/ f"wechatpadpro_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:

View File

@@ -1,7 +1,6 @@
import asyncio
import base64
import io
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
import aiohttp
@@ -51,21 +50,6 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
await self._send_voice(session, comp)
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
b64 = await comp.convert_to_base64()
raw = self._validate_base64(b64)


@@ -110,7 +110,7 @@ class WecomServer:
await self.shutdown_event.wait()
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
@register_platform_adapter("wecom", "wecom 适配器")
class WecomPlatformAdapter(Platform):
def __init__(
self,
@@ -196,7 +196,6 @@ class WecomPlatformAdapter(Platform):
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
support_streaming_message=False,
)
@override


@@ -10,7 +10,7 @@ import base64
import hashlib
import json
import logging
import secrets
import random
import socket
import struct
import time
@@ -139,12 +139,6 @@ class PKCS7Encoder:
class Prpcrypt:
"""提供接收和推送给企业微信消息的加解密接口"""
# 16位随机字符串的范围常量
# randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999]两端都包含即包含0和8999999999999999
# 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999]两端都包含即16位数字
MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位)
RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位)
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
@@ -213,9 +207,7 @@ class Prpcrypt:
"""随机生成16位字符串
@return: 16位字符串
"""
return str(
secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE
).encode()
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizJsonMsgCrypt:

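The `get_random_str` hunk above toggles between `random.randint` and a `secrets`-based expression; either way the result must be exactly 16 digits. The range arithmetic from the removed comments can be checked directly:

```python
import secrets

MIN_RANDOM_VALUE = 1_000_000_000_000_000  # smallest 16-digit number
RANDOM_RANGE = 9_000_000_000_000_000      # randbelow(n) returns values in [0, n-1]


def get_random_str() -> bytes:
    """Cryptographically random 16-digit string, returned as bytes."""
    return str(secrets.randbelow(RANDOM_RANGE) + MIN_RANDOM_VALUE).encode()


s = get_random_str()
print(len(s), s.isdigit())  # 16 True
```

`secrets.randbelow(RANDOM_RANGE)` yields at most 8999999999999999, so the offset sum never exceeds 9999999999999999 and never drops a leading digit.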

@@ -5,7 +5,7 @@
import asyncio
import base64
import hashlib
import secrets
import random
import string
from typing import Any
@@ -53,7 +53,7 @@ def generate_random_string(length: int = 10) -> str:
"""
letters = string.ascii_letters + string.digits
return "".join(secrets.choice(letters) for _ in range(length))
return "".join(random.choice(letters) for _ in range(length))
def calculate_image_md5(image_data: bytes) -> str:

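`generate_random_string` likewise switches between `random.choice` and `secrets.choice`. The CSPRNG variant, self-contained:

```python
import secrets
import string


def generate_random_string(length: int = 10) -> str:
    """Random alphanumeric string drawn from a CSPRNG."""
    letters = string.ascii_letters + string.digits
    return "".join(secrets.choice(letters) for _ in range(length))


print(len(generate_random_string()))  # 10
```

`random` is fine for jitter or sampling, but tokens and nonces that gate access should come from `secrets`.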

@@ -113,9 +113,7 @@ class WecomServer:
await self.shutdown_event.wait()
@register_platform_adapter(
"weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
)
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
class WeixinOfficialAccountPlatformAdapter(Platform):
def __init__(
self,
@@ -197,7 +195,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
support_streaming_message=False,
)
@override


@@ -1,4 +1,4 @@
from .entities import ProviderMetaData
from .provider import Provider, STTProvider
from .provider import Personality, Provider, STTProvider
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]


@@ -30,31 +30,18 @@ class ProviderType(enum.Enum):
@dataclass
class ProviderMeta:
"""The basic metadata of a provider instance."""
id: str
"""the unique id of the provider instance that user configured"""
model: str | None
"""the model name of the provider instance currently used"""
class ProviderMetaData:
type: str
"""the name of the provider adapter, such as openai, ollama"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
"""the capability type of the provider adapter"""
@dataclass
class ProviderMetaData(ProviderMeta):
"""The metadata of a provider adapter for registration."""
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""the short description of the provider adapter"""
"""提供商适配器描述"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Any = None
"""the class type of the provider adapter"""
default_config_tmpl: dict | None = None
"""the default configuration template of the provider adapter"""
"""平台的默认配置模板"""
provider_display_name: str | None = None
"""the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
@@ -73,20 +60,12 @@ class ToolCallsResult:
]
return ret
def to_openai_messages_model(
self,
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
return [
self.tool_calls_info,
*self.tool_calls_result,
]
@dataclass
class ProviderRequest:
prompt: str | None = None
prompt: str
"""提示词"""
session_id: str | None = ""
session_id: str = ""
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
@@ -202,28 +181,25 @@ class ProviderRequest:
@dataclass
class LLMResponse:
role: str
"""The role of the message, e.g., assistant, tool, err"""
"""角色, assistant, tool, err"""
result_chain: MessageChain | None = None
"""A chain of message components representing the text completion from LLM."""
"""返回的消息链"""
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
"""Tool call arguments."""
"""工具调用参数"""
tools_call_name: list[str] = field(default_factory=list)
"""Tool call names."""
"""工具调用名称"""
tools_call_ids: list[str] = field(default_factory=list)
"""Tool call IDs."""
reasoning_content: str = ""
"""The reasoning content extracted from the LLM, if any."""
"""工具调用 ID"""
raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
) = None
"""The raw completion response from the LLM provider."""
_new_record: dict[str, Any] | None = None
_completion_text: str = ""
"""The plain text of the completion."""
is_chunk: bool = False
"""Indicates if the response is a chunked response."""
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
@@ -237,6 +213,7 @@ class LLMResponse:
| GenerateContentResponse
| AnthropicMessage
| None = None,
_new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -264,6 +241,7 @@ class LLMResponse:
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property


@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
@@ -25,16 +24,7 @@ SUPPORTED_TYPES = [
"boolean",
] # data types supported by JSON Schema
PY_TO_JSON_TYPE = {
"int": "number",
"float": "number",
"bool": "boolean",
"str": "string",
"dict": "object",
"list": "array",
"tuple": "array",
"set": "array",
}
# alias
FuncTool = FunctionTool
@@ -116,7 +106,7 @@ class FunctionToolManager:
def spec_to_func(
self,
name: str,
func_args: list[dict],
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
@@ -125,9 +115,10 @@ class FunctionToolManager:
"properties": {},
}
for param in func_args:
p = copy.deepcopy(param)
p.pop("name", None)
params["properties"][param["name"]] = p
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
return FuncTool(
name=name,
parameters=params,

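`spec_to_func` above converts a list of parameter dicts into a JSON-schema `parameters` object; the deep-copy-and-pop variant preserves extra schema keys such as `enum`. A standalone sketch of that conversion:

```python
import copy


def build_parameters(func_args: list[dict]) -> dict:
    """Turn [{'name', 'type', 'description', ...}] into a JSON-schema object."""
    params: dict = {"type": "object", "properties": {}}
    for param in func_args:
        p = copy.deepcopy(param)
        p.pop("name", None)  # 'name' becomes the property key, not a schema field
        params["properties"][param["name"]] = p
    return params


schema = build_parameters(
    [{"name": "city", "type": "string", "description": "City name", "enum": ["NY", "LA"]}]
)
print(schema["properties"]["city"]["enum"])  # ['NY', 'LA']
```

The alternative shown in the diff copies only `type` and `description`, which would silently drop `enum`, `default`, and similar keys.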

@@ -241,8 +241,6 @@ class ProviderManager:
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
@@ -356,8 +354,6 @@ class ProviderManager:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT task
inst = cls_type(provider_config, self.provider_settings)
@@ -398,6 +394,7 @@ class ProviderManager:
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
)
if getattr(inst, "initialize", None):


@@ -1,18 +1,28 @@
import abc
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMeta,
ProviderType,
RerankResult,
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
@dataclass
class ProviderMeta:
id: str
model: str
type: str
provider_type: ProviderType
class AbstractProvider(abc.ABC):
"""Provider Abstract Class"""
@@ -33,15 +43,15 @@ class AbstractProvider(abc.ABC):
"""Get the provider metadata"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
if not meta_data:
raise ValueError(f"Provider type {provider_type_name} not registered")
meta = ProviderMeta(
id=self.provider_config.get("id", "default"),
provider_type = meta_data.provider_type if meta_data else None
if provider_type is None:
raise ValueError(f"Cannot find provider type: {provider_type_name}")
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
type=provider_type_name,
provider_type=meta_data.provider_type,
provider_type=provider_type,
)
return meta
class Provider(AbstractProvider):
@@ -51,10 +61,15 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
self.curr_personality = default_persona
"""维护了当前的使用的 persona即人格。可能为 None"""
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError


@@ -36,8 +36,6 @@ def register_provider_adapter(
default_config_tmpl["id"] = provider_type_name
pm = ProviderMetaData(
id="default", # will be replaced when instantiated
model=None,
type=provider_type_name,
desc=desc,
provider_type=provider_type,


@@ -25,10 +25,12 @@ class ProviderAnthropic(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.chosen_api_key: str = ""


@@ -1,8 +1,8 @@
import asyncio
import hashlib
import json
import random
import re
import secrets
import time
import uuid
from pathlib import Path
@@ -54,9 +54,7 @@ class OTTSProvider:
async def _generate_signature(self) -> str:
await self._sync_time()
timestamp = int(time.time()) + self.time_offset
nonce = "".join(
secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
)
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"

@@ -20,10 +20,12 @@ class ProviderCoze(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:

@@ -8,7 +8,7 @@ from dashscope.app.application_response import ApplicationResponse
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from .. import Provider
from .. import Personality, Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@@ -20,11 +20,13 @@ class ProviderDashscope(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:

@@ -15,7 +15,11 @@ except (
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -45,7 +49,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):

@@ -18,10 +18,12 @@ class ProviderDify(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:

@@ -4,9 +4,12 @@ import subprocess
import uuid
import edge_tts
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -46,8 +49,10 @@ class ProviderEdgeTTS(TTSProvider):
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
mp3_path = str(
AstrbotPaths.astrbot_root / "temp" / f"edge_tts_temp_{uuid.uuid4()}.mp3"
)
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
# Build the Edge TTS arguments

@@ -53,10 +53,12 @@ class ProviderGoogleGenAI(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_keys: list = super().get_keys()
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
@@ -324,18 +326,8 @@ class ProviderGoogleGenAI(Provider):
return gemini_contents
def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
"""Extract reasoning content from candidate parts"""
if not candidate.content or not candidate.content.parts:
return ""
thought_buf: list[str] = [
(p.text or "") for p in candidate.content.parts if p.thought
]
return "".join(thought_buf).strip()
@staticmethod
def _process_content_parts(
self,
candidate: types.Candidate,
llm_response: LLMResponse,
) -> MessageChain:
@@ -366,11 +358,6 @@ class ProviderGoogleGenAI(Provider):
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
raise Exception("API 返回的 candidate.content.parts 为空。")
# Extract the reasoning content
reasoning = self._extract_reasoning_content(candidate)
if reasoning:
llm_response.reasoning_content = reasoning
chain = []
part: types.Part
@@ -528,7 +515,6 @@ class ProviderGoogleGenAI(Provider):
# Accumulate the complete response text for the final response
accumulated_text = ""
accumulated_reasoning = ""
final_response = None
async for chunk in result:
@@ -553,19 +539,9 @@ class ProviderGoogleGenAI(Provider):
yield llm_response
return
_f = False
# Extract the reasoning content
reasoning = self._extract_reasoning_content(chunk.candidates[0])
if reasoning:
_f = True
accumulated_reasoning += reasoning
llm_response.reasoning_content = reasoning
if chunk.text:
_f = True
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
if _f:
yield llm_response
if chunk.candidates[0].finish_reason:
@@ -583,10 +559,6 @@ class ProviderGoogleGenAI(Provider):
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
# Set the complete accumulated reasoning in the final response
if accumulated_reasoning:
final_response.reasoning_content = accumulated_reasoning
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(

@@ -1,15 +0,0 @@
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter(
"groq_chat_completion", "Groq Chat Completion Provider Adapter"
)
class ProviderGroq(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.reasoning_key = "reasoning"

@@ -4,14 +4,12 @@ import inspect
import json
import os
import random
import re
from collections.abc import AsyncGenerator
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
import astrbot.core.message.components as Comp
from astrbot import logger
@@ -30,37 +28,37 @@ from ..register import register_provider_adapter
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
def __init__(self, provider_config, provider_settings) -> None:
super().__init__(provider_config, provider_settings)
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.chosen_api_key = None
self.api_keys: list = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
self.custom_headers = provider_config.get("custom_headers", {})
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
if not isinstance(self.custom_headers, dict) or not self.custom_headers:
self.custom_headers = None
else:
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])
# Adapt for Azure OpenAI (#332)
if "api_version" in provider_config:
# Using the Azure OpenAI API
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
default_headers=self.custom_headers,
base_url=provider_config.get("api_base", ""),
timeout=self.timeout,
)
else:
# Using the official OpenAI API
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
default_headers=self.custom_headers,
timeout=self.timeout,
)
@@ -72,8 +70,6 @@ class ProviderOpenAIOfficial(Provider):
model = model_config.get("model", "unknown")
self.set_model(model)
self.reasoning_key = "reasoning_content"
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
@@ -151,7 +147,7 @@ class ProviderOpenAIOfficial(Provider):
logger.debug(f"completion: {completion}")
llm_response = await self._parse_openai_completion(completion, tools)
llm_response = await self.parse_openai_completion(completion, tools)
return llm_response
@@ -204,78 +200,36 @@ class ProviderOpenAIOfficial(Provider):
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
# logger.debug(f"chunk delta: {delta}")
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
_y = False
if reasoning:
llm_response.reasoning_content = reasoning
_y = True
# Handle the text content
if delta.content:
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
_y = True
if _y:
yield llm_response
final_completion = state.get_final_completion()
llm_response = await self._parse_openai_completion(final_completion, tools)
llm_response = await self.parse_openai_completion(final_completion, tools)
yield llm_response
def _extract_reasoning_content(
self,
completion: ChatCompletion | ChatCompletionChunk,
) -> str:
"""Extract reasoning content from OpenAI ChatCompletion if available."""
reasoning_text = ""
if len(completion.choices) == 0:
return reasoning_text
if isinstance(completion, ChatCompletion):
choice = completion.choices[0]
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
elif isinstance(completion, ChatCompletionChunk):
delta = completion.choices[0].delta
reasoning_attr = getattr(delta, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
return reasoning_text
async def _parse_openai_completion(
async def parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
"""Parse OpenAI ChatCompletion into LLMResponse"""
"""解析 OpenAI ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
# parse the text completion
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
# Some providers wrap reasoning content in <think> tags inside the completion text;
# we strip them with a regex and store the content in the reasoning_content field
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
matches = reasoning_pattern.findall(completion_text)
if matches:
llm_response.reasoning_content = "\n".join(
[match.strip() for match in matches],
)
completion_text = reasoning_pattern.sub("", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)
# Parse the reasoning content, if any;
# this takes priority over the <think>-tag extraction above
llm_response.reasoning_content = self._extract_reasoning_content(completion)
# parse tool calls if any
if choice.message.tool_calls and tools is not None:
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_call_ids = []
@@ -301,11 +255,11 @@ class ProviderOpenAIOfficial(Provider):
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids
# specially handle finish reason
if choice.finish_reason == "content_filter":
raise Exception(
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
)
if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")

@@ -1,13 +1,16 @@
import asyncio
import base64
import json
import os
import traceback
import uuid
import aiohttp
from astrbot_api.abc import IAstrbotPaths
from astrbot import logger
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -92,9 +95,10 @@ class ProviderVolcengineTTS(TTSProvider):
if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"])
os.makedirs("data/temp", exist_ok=True)
temp_dir = AstrbotPaths.astrbot_root / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
file_path = str(temp_dir / f"volcengine_tts_{uuid.uuid4()}.mp3")
loop = asyncio.get_running_loop()
await loop.run_in_executor(

@@ -1,10 +1,13 @@
import os
import uuid
from astrbot_api.abc import IAstrbotPaths
from openai import NOT_GIVEN, AsyncOpenAI
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@@ -53,8 +56,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
is_tencent = True
name = str(uuid.uuid4())
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, name)
path = str(AstrbotPaths.astrbot_root / "temp" / name)
await download_file(audio_url, path)
audio_url = path
@@ -65,8 +67,9 @@ class ProviderOpenAIWhisperAPI(STTProvider):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
output_path = str(
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.wav"
)
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path

@@ -12,5 +12,10 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
default_persona=None,
) -> None:
super().__init__(provider_config, provider_settings)
super().__init__(
provider_config,
provider_settings,
default_persona,
)

@@ -3,7 +3,11 @@
import json
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
def load_config(namespace: str) -> dict | bool:
@@ -11,7 +15,7 @@ def load_config(namespace: str) -> dict | bool:
namespace: str, the unique identifier of the config, i.e., the name of the config file.
Returns: the content (dict) of the config file for `namespace` if it exists; otherwise False.
"""
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
path = str(AstrbotPaths.astrbot_root / "config" / f"{namespace}.json")
if not os.path.exists(path):
return False
with open(path, encoding="utf-8-sig") as f:
@@ -41,8 +45,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str):
if not isinstance(value, (str, int, float, bool, list)):
raise ValueError("value 只支持 str, int, float, bool, list 类型。")
config_dir = os.path.join(get_astrbot_data_path(), "config")
path = os.path.join(config_dir, f"{namespace}.json")
path = str(AstrbotPaths.astrbot_root / "config" / f"{namespace}.json")
if not os.path.exists(path):
with open(path, "w", encoding="utf-8-sig") as f:
@@ -70,7 +73,7 @@ def update_config(namespace: str, key: str, value):
key: str, the key of the config entry.
value: str, int, float, bool, or list, the value of the config entry.
"""
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
path = str(AstrbotPaths.astrbot_root / "config" / f"{namespace}.json")
if not os.path.exists(path):
raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。")
with open(path, encoding="utf-8-sig") as f:

@@ -5,10 +5,6 @@ from typing import Any
from deprecated import deprecated
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.conversation_mgr import ConversationManager
@@ -17,10 +13,10 @@ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.platform import Platform
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
@@ -35,7 +31,6 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
)
from ..exceptions import ProviderNotFoundError
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from .star import StarMetadata, star_map, star_registry
@@ -80,153 +75,6 @@ class Context:
self.astrbot_config_mgr = astrbot_config_mgr
self.kb_manager = knowledge_base_manager
async def llm_generate(
self,
*,
chat_provider_id: str,
prompt: str | None = None,
image_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
.. versionadded:: 4.5.7 (sdk)
Args:
chat_provider_id: The chat provider ID to use.
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
tools: ToolSet of tools available to the LLM
system_prompt: System prompt to guide the LLM's behavior; if provided, it is always inserted as the first system message in the context
contexts: context messages for the LLM
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
Raises:
ChatProviderNotFoundError: If the specified chat provider ID is not found
Exception: For other errors during LLM generation
"""
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
if not prov or not isinstance(prov, Provider):
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
llm_resp = await prov.text_chat(
prompt=prompt,
image_urls=image_urls,
func_tool=tools,
contexts=contexts,
system_prompt=system_prompt,
**kwargs,
)
return llm_resp
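The lookup-and-delegate shape of `llm_generate` can be illustrated with fakes standing in for the provider manager and `Provider` (none of these classes are the real astrbot types):

```python
import asyncio


class FakeProvider:
    # Stand-in for a chat Provider exposing the text_chat coroutine.
    async def text_chat(self, prompt, **kwargs):
        return f"echo: {prompt}"


class FakeManager:
    # Stand-in for ProviderManager.get_provider_by_id.
    def __init__(self):
        self._providers = {"gpt": FakeProvider()}

    async def get_provider_by_id(self, provider_id):
        return self._providers.get(provider_id)


async def llm_generate(manager, chat_provider_id: str, prompt: str):
    # Same shape as Context.llm_generate: resolve the provider by id,
    # raise if missing, then delegate to text_chat.
    prov = await manager.get_provider_by_id(chat_provider_id)
    if prov is None:
        raise LookupError(f"Provider {chat_provider_id} not found")
    return await prov.text_chat(prompt=prompt)
```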
async def tool_loop_agent(
self,
*,
event: AstrMessageEvent,
chat_provider_id: str,
prompt: str | None = None,
image_urls: list[str] | None = None,
tools: ToolSet | None = None,
system_prompt: str | None = None,
contexts: list[Message] | None = None,
max_steps: int = 30,
tool_call_timeout: int = 60,
**kwargs: Any,
) -> LLMResponse:
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
If you do not pass the agent_context parameter, the method creates a new agent context.
.. versionadded:: 4.5.7 (sdk)
Args:
chat_provider_id: The chat provider ID to use.
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
tools: ToolSet of tools available to the LLM
system_prompt: System prompt to guide the LLM's behavior; if provided, it is always inserted as the first system message in the context
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
Returns:
The final LLMResponse after tool calls are completed.
Raises:
ChatProviderNotFoundError: If the specified chat provider ID is not found
Exception: For other errors during LLM generation
"""
# Import here to avoid circular imports
from astrbot.core.astr_agent_context import (
AgentContextWrapper,
AstrAgentContext,
)
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
if not prov or not isinstance(prov, Provider):
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
agent_context = kwargs.get("agent_context")
context_ = []
for msg in contexts or []:
if isinstance(msg, Message):
context_.append(msg.model_dump())
else:
context_.append(msg)
request = ProviderRequest(
prompt=prompt,
image_urls=image_urls or [],
func_tool=tools,
contexts=context_,
system_prompt=system_prompt or "",
)
if agent_context is None:
agent_context = AstrAgentContext(
context=self,
event=event,
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
await agent_runner.reset(
provider=prov,
request=request,
run_context=AgentContextWrapper(
context=agent_context,
tool_call_timeout=tool_call_timeout,
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
)
async for _ in agent_runner.step_until_done(max_steps):
pass
llm_resp = agent_runner.get_final_llm_resp()
if not llm_resp:
raise Exception("Agent did not produce a final LLM response")
return llm_resp
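The tool loop above can be reduced to a schematic, under stated assumptions: `llm` returns either `("tool", name, args)` or `("final", text)`, and `tools` maps names to plain callables. The real `ToolLoopAgentRunner` is async and supports streaming; this sketch only shows the bounded call-execute-feed-back cycle:

```python
def tool_loop(llm, tools, prompt, max_steps=30):
    history = [("user", prompt)]
    for _ in range(max_steps):
        kind, *rest = llm(history)
        if kind == "final":
            return rest[0]
        name, args = rest
        # Execute the tool and feed its result back into the history.
        history.append(("tool", name, tools[name](**args)))
    raise RuntimeError("Agent did not produce a final LLM response")
```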
async def get_current_chat_provider_id(self, umo: str) -> str:
"""Get the ID of the currently used chat provider.
Args:
umo (str): the unified_message_origin value. If provided and the user has enabled provider session isolation, the provider preferred by that session is used.
Raises:
ProviderNotFoundError: If the specified chat provider is not found
"""
prov = self.get_using_provider(umo)
if not prov:
raise ProviderNotFoundError("Provider not found")
return prov.meta().id
def get_registered_star(self, star_name: str) -> StarMetadata | None:
"""根据插件名获取插件的 Metadata"""
for star in star_registry:
@@ -259,6 +107,10 @@ class Context:
"""
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(
self,
provider_id: str,
@@ -337,6 +189,45 @@ class Context:
return self._config
return self.astrbot_config_mgr.get_conf(umo)
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
return self._db
def get_event_queue(self) -> Queue:
"""获取事件队列。"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if name == platform_type:
return platform
elif (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
async def send_message(
self,
session: str | MessageSesion,
@@ -409,49 +300,6 @@ class Context:
The methods below are deprecated; see the AstrBot documentation for the preferred registration approaches.
"""
def get_event_queue(self) -> Queue:
"""获取事件队列。"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if name == platform_type:
return platform
elif (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
return self._db
def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
self.provider_manager.provider_insts.append(provider)
def register_llm_tool(
self,
name: str,

@@ -1,6 +1,5 @@
from __future__ import annotations
import re
from collections.abc import Awaitable, Callable
from typing import Any
@@ -12,7 +11,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from ..filter.command import CommandFilter
@@ -418,37 +417,18 @@ def register_llm_tool(name: str | None = None, **kwargs):
docstring = docstring_parser.parse(func_doc)
args = []
for arg in docstring.params:
sub_type_name = None
type_name = arg.type_name
if not type_name:
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
)
# parse type_name to handle cases like "list[string]"
match = re.match(r"(\w+)\[(\w+)\]", type_name)
if match:
type_name = match.group(1)
sub_type_name = match.group(2)
type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
if sub_type_name:
sub_type_name = PY_TO_JSON_TYPE.get(sub_type_name, sub_type_name)
if type_name not in SUPPORTED_TYPES or (
sub_type_name and sub_type_name not in SUPPORTED_TYPES
):
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
)
arg_json_schema = {
"type": type_name,
"name": arg.arg_name,
"description": arg.description,
}
if sub_type_name:
if type_name == "array":
arg_json_schema["items"] = {"type": sub_type_name}
args.append(arg_json_schema)
args.append(
{
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description,
},
)
# print(llm_tool_name, registering_agent)
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
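The removed type-parsing logic above (handling annotations like `list[string]`) can be shown in isolation. The mappings here are illustrative stand-ins for `PY_TO_JSON_TYPE` / `SUPPORTED_TYPES`, not their actual contents:

```python
import re

# Hypothetical mappings mirroring the shape of PY_TO_JSON_TYPE / SUPPORTED_TYPES.
PY_TO_JSON_TYPE = {"str": "string", "int": "number", "list": "array"}
SUPPORTED_TYPES = {"string", "number", "boolean", "object", "array"}


def parse_param_type(type_name: str) -> dict:
    """Parse annotations like "list[string]" into a JSON-schema fragment."""
    match = re.match(r"(\w+)\[(\w+)\]", type_name)
    sub = None
    if match:
        type_name, sub = match.group(1), match.group(2)
    type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
    if sub:
        sub = PY_TO_JSON_TYPE.get(sub, sub)
    if type_name not in SUPPORTED_TYPES or (sub and sub not in SUPPORTED_TYPES):
        raise ValueError(f"Unsupported parameter type: {type_name}")
    schema = {"type": type_name}
    if sub and type_name == "array":
        schema["items"] = {"type": sub}
    return schema
```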

@@ -680,18 +680,11 @@ class PluginManager:
return plugin_info
async def uninstall_plugin(
self,
plugin_name: str,
delete_config: bool = False,
delete_data: bool = False,
):
async def uninstall_plugin(self, plugin_name: str):
"""卸载指定的插件。
Args:
plugin_name (str): 要卸载的插件名称
delete_config (bool): 是否删除插件配置文件,默认为 False
delete_data (bool): 是否删除插件数据,默认为 False
Raises:
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
@@ -721,7 +714,6 @@ class PluginManager:
await self._unbind_plugin(plugin_name, plugin.module_path)
# Delete the plugin folder
try:
remove_dir(os.path.join(ppath, root_dir_name))
except Exception as e:
@@ -729,51 +721,6 @@ class PluginManager:
f"移除插件成功,但是删除插件文件夹失败: {e!s}。您可以手动删除该文件夹,位于 addons/plugins/ 下。",
)
# Delete the plugin's config file
if delete_config and root_dir_name:
config_file = os.path.join(
self.plugin_config_path,
f"{root_dir_name}_config.json",
)
if os.path.exists(config_file):
try:
os.remove(config_file)
logger.info(f"已删除插件 {plugin_name} 的配置文件")
except Exception as e:
logger.warning(f"删除插件配置文件失败: {e!s}")
# Delete the plugin's persisted data.
# Note: two possible directory names must be checked: plugin_data and plugins_data.
# data/temp may be shared by multiple plugins, so it is not deleted automatically.
if delete_data and root_dir_name:
data_base_dir = os.path.dirname(ppath) # data/
# Delete the plugin's persisted data under data/plugin_data (singular form, newer versions)
plugin_data_dir = os.path.join(
data_base_dir, "plugin_data", root_dir_name
)
if os.path.exists(plugin_data_dir):
try:
remove_dir(plugin_data_dir)
logger.info(
f"已删除插件 {plugin_name} 的持久化数据 (plugin_data)"
)
except Exception as e:
logger.warning(f"删除插件持久化数据失败 (plugin_data): {e!s}")
# Delete the plugin's persisted data under data/plugins_data (plural form, legacy compatibility)
plugins_data_dir = os.path.join(
data_base_dir, "plugins_data", root_dir_name
)
if os.path.exists(plugins_data_dir):
try:
remove_dir(plugins_data_dir)
logger.info(
f"已删除插件 {plugin_name} 的持久化数据 (plugins_data)"
)
except Exception as e:
logger.warning(f"删除插件持久化数据失败 (plugins_data): {e!s}")
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
"""解绑并移除一个插件。

@@ -6,7 +6,7 @@ import psutil
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.utils.astrbot_path import get_astrbot_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from .zip_updator import ReleaseInfo, RepoZipUpdator
@@ -20,7 +20,7 @@ class AstrBotUpdator(RepoZipUpdator):
def __init__(self, repo_mirror: str = "") -> None:
super().__init__(repo_mirror)
self.MAIN_PATH = get_astrbot_path()
self.astrbot_root = get_astrbot_data_path()
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
def terminate_child_processes(self):
@@ -117,7 +117,7 @@ class AstrBotUpdator(RepoZipUpdator):
try:
await download_file(file_url, "temp.zip")
logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...")
self.unzip_file("temp.zip", self.MAIN_PATH)
self.unzip_file("temp.zip", self.astrbot_root)
except BaseException as e:
raise e

@@ -1,39 +1,70 @@
"""Astrbot统一路径获取
项目路径:固定为源码所在路径
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
数据目录路径:固定为根目录下的 data 目录
配置文件路径:固定为数据目录下的 config 目录
插件目录路径:固定为数据目录下的 plugins 目录
"""
import os
import warnings
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths = sync_base_container.get(type[IAstrbotPaths])
def get_astrbot_path() -> str:
"""获取Astrbot项目路径"""
"""获取Astrbot项目路径 -- 请勿使用本函数!!! -- 仅供兼容旧代码使用"""
warnings.warn(
"get_astrbot_path is deprecated. Use AstrbotPaths class instead.",
DeprecationWarning,
stacklevel=2,
)
return os.path.realpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"),
)
def get_astrbot_root() -> str:
"""获取Astrbot根目录路径"""
if path := os.environ.get("ASTRBOT_ROOT"):
return os.path.realpath(path)
return os.path.realpath(os.getcwd())
"""获取Astrbot根目录路径 --> get_astrbot_data_path"""
warnings.warn(
"不要再使用本函数!等效于: AstrbotPaths.astrbot_root",
DeprecationWarning,
stacklevel=2,
)
return str(AstrbotPaths.astrbot_root)
def get_astrbot_data_path() -> str:
"""获取Astrbot数据目录路径"""
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
"""获取Astrbot数据目录路径
特别注意!
这里的data目录指的就是.astrbot根目录!
两者是等价的!
不要和AstrbotPaths.data混淆!
"""
warnings.warn(
"等效于: AstrbotPaths.astrbot_root",
DeprecationWarning,
stacklevel=2,
)
return str(AstrbotPaths.astrbot_root)
def get_astrbot_config_path() -> str:
"""获取Astrbot配置文件路径"""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
warnings.warn(
"get_astrbot_config_path is deprecated. Use AstrbotPaths class instead.",
DeprecationWarning,
stacklevel=2,
)
return str(AstrbotPaths.astrbot_root / "config")
def get_astrbot_plugin_path() -> str:
"""Get the Astrbot plugin path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
warnings.warn(
"get_astrbot_plugin_path is deprecated. Use AstrbotPaths class instead.",
DeprecationWarning,
stacklevel=2,
)
return str(AstrbotPaths.astrbot_root / "plugins")
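The shims above all follow the same pattern: keep the old function signature, emit a `DeprecationWarning`, and delegate to the new path API. A minimal self-contained sketch of that pattern (the `Paths`/`get_root` names here are illustrative stand-ins, not the actual AstrBot API):

```python
import warnings
from pathlib import Path


class Paths:
    """Stand-in for AstrbotPaths: a class attribute holds the root."""

    root = Path.home() / ".astrbot"


def get_root() -> str:
    """Deprecated accessor kept only for backward compatibility."""
    warnings.warn(
        "get_root is deprecated. Use Paths.root instead.",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this shim
    )
    return str(Paths.root)


# Callers still get the right value, plus a recorded deprecation warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = get_root()

assert value == str(Paths.root)
assert caught and issubclass(caught[0].category, DeprecationWarning)
```

`stacklevel=2` is what makes the warning report the deprecated call site rather than the shim itself, matching the diff above.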


@@ -12,12 +12,16 @@ from pathlib import Path
import aiohttp
import certifi
import psutil
from astrbot_api.abc import IAstrbotPaths
from PIL import Image
from astrbot_sdk import sync_base_container
from .astrbot_path import get_astrbot_data_path
logger = logging.getLogger("astrbot")
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
def on_error(func, path, exc_info):
"""A callback of the rmtree function."""
@@ -50,11 +54,11 @@ def port_checker(port: int, host: str = "localhost"):
def save_temp_img(img: Image.Image | str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
# Check file creation times and purge files older than 12 hours
try:
for f in os.listdir(temp_dir):
path = os.path.join(temp_dir, f)
path = str(AstrbotPaths.astrbot_root / "temp" / f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600 * 12:
@@ -64,7 +68,7 @@ def save_temp_img(img: Image.Image | str) -> str:
# Build a timestamp
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = os.path.join(temp_dir, f"{timestamp}.jpg")
p = str(AstrbotPaths.astrbot_root / "temp" / f"{timestamp}.jpg")
if isinstance(img, Image.Image):
img.save(p)
@@ -105,31 +109,16 @@ async def download_image_by_url(
f.write(await resp.read())
return path
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# Disable SSL verification only as a fallback when certificate verification fails
logger.warning(
f"SSL certificate verification failed for {url}. "
"Disabling SSL verification (CERT_NONE) as a fallback. "
"This is insecure and exposes the application to man-in-the-middle attacks. "
"Please investigate and resolve certificate issues."
)
# Disable SSL verification
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.set_ciphers("DEFAULT")
async with aiohttp.ClientSession() as session:
if post:
async with session.post(url, json=post_data, ssl=ssl_context) as resp:
if not path:
return save_temp_img(await resp.read())
with open(path, "wb") as f:
f.write(await resp.read())
return path
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
else:
async with session.get(url, ssl=ssl_context) as resp:
if not path:
return save_temp_img(await resp.read())
with open(path, "wb") as f:
f.write(await resp.read())
return path
return save_temp_img(await resp.read())
except Exception as e:
raise e
@@ -172,19 +161,9 @@ async def download_file(url: str, path: str, show_progress: bool = False):
end="",
)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# Disable SSL verification only as a fallback when certificate verification fails
logger.warning(
"SSL 证书验证失败,已关闭 SSL 验证(不安全,仅用于临时下载)。请检查目标服务器的证书配置。"
)
logger.warning(
f"SSL certificate verification failed for {url}. "
"Falling back to unverified connection (CERT_NONE). "
"This is insecure and exposes the application to man-in-the-middle attacks. "
"Please investigate certificate issues with the remote server."
)
# Disable SSL verification
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.set_ciphers("DEFAULT")
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get("content-length", 0))
@@ -230,9 +209,9 @@ def get_local_ip_addresses():
async def get_dashboard_version():
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
dist_dir = str(AstrbotPaths.astrbot_root / "dist")
if os.path.exists(dist_dir):
version_file = os.path.join(dist_dir, "assets", "version")
version_file = str(AstrbotPaths.astrbot_root / "dist" / "assets" / "version")
if os.path.exists(version_file):
with open(version_file, encoding="utf-8") as f:
v = f.read().strip()


@@ -2,8 +2,9 @@
import os
import shutil
from importlib import resources
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class TemplateManager:
@@ -15,14 +16,10 @@ class TemplateManager:
CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]
def __init__(self):
self.builtin_template_dir = os.path.join(
get_astrbot_path(),
"astrbot",
"core",
"utils",
"t2i",
"template",
self.builtin_template_dir = str(
resources.files("astrbot.core.utils.t2i.template")
)
self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")
os.makedirs(self.user_template_dir, exist_ok=True)
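The switch to `resources.files(...)` above resolves a directory inside an installed package regardless of how the package is installed (wheel, zip, editable install), which hard-coded `os.path.join` on the source tree cannot do. A quick sketch of the API using a stdlib package for illustration:

```python
from importlib import resources

# Resolve the directory of an installed package
# (stdlib "json" is used here purely for illustration).
pkg_dir = resources.files("json")

# The returned Traversable supports "/" joining and iteration, like pathlib.Path.
entries = [entry.name for entry in pkg_dir.iterdir()]
assert "__init__.py" in entries
```

Casting the result with `str(...)`, as the diff does, works for regular on-disk installs; for zip imports, `resources.as_file(...)` would be needed to get a real filesystem path.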


@@ -4,15 +4,18 @@ import os
import uuid
from contextlib import asynccontextmanager
from astrbot_api.abc import IAstrbotPaths
from quart import Response as QuartResponse
from quart import g, make_response, request
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .route import Response, Route, RouteContext
@@ -47,7 +50,7 @@ class ChatRoute(Route):
}
self.core_lifecycle = core_lifecycle
self.register_routes()
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
self.imgs_dir = str(AstrbotPaths.astrbot_root / "webchat" / "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
@@ -125,8 +128,6 @@ class ChatRoute(Route):
audio_url = post_data.get("audio_url")
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)  # defaults to True
if not message and not image_url and not audio_url:
return (
Response()
@@ -204,8 +205,6 @@ class ChatRoute(Route):
):
# Append the bot message
new_his = {"type": "bot", "message": result_text}
if "reasoning" in result:
new_his["reasoning"] = result["reasoning"]
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
@@ -228,7 +227,6 @@ class ChatRoute(Route):
"audio_url": audio_url,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
},
),
)


@@ -1,4 +1,5 @@
import asyncio
import importlib.resources
import inspect
import os
import traceback
@@ -21,7 +22,6 @@ from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core.utils.astrbot_path import get_astrbot_path
from .route import Response, Route, RouteContext
@@ -461,11 +461,12 @@ class ConfigRoute(Route):
logger.debug(
f"Sending health check audio to provider: {status_info['name']}",
)
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
sample_audio_path = str(
importlib.resources.files("astrbot")
/ "samples"
/ "stt_health_check.wav"
)
if not os.path.exists(sample_audio_path):
status_info["status"] = "unavailable"
status_info["error"] = (


@@ -395,15 +395,9 @@ class PluginRoute(Route):
post_data = await request.json
plugin_name = post_data["name"]
delete_config = post_data.get("delete_config", False)
delete_data = post_data.get("delete_data", False)
try:
logger.info(f"正在卸载插件 {plugin_name}")
await self.plugin_manager.uninstall_plugin(
plugin_name,
delete_config=delete_config,
delete_data=delete_data,
)
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:


@@ -296,15 +296,7 @@ class ToolsRoute(Route):
"""获取所有注册的工具列表"""
try:
tools = self.tool_mgr.func_list
tools_dict = [
{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"active": tool.active,
}
for tool in tools
]
tools_dict = [tool.__dict__() for tool in tools]
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())


@@ -5,6 +5,7 @@ import socket
import jwt
import psutil
from astrbot_api.abc import IAstrbotPaths
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
@@ -12,8 +13,8 @@ from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import get_local_ip_addresses
from astrbot_sdk import sync_base_container
from .routes import *
from .routes.route import Response, RouteContext
@@ -22,7 +23,7 @@ from .routes.t2i import T2iRoute
APP: Quart = None
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
class AstrBotDashboard:
def __init__(
self,
@@ -38,9 +39,7 @@ class AstrBotDashboard:
if webui_dir and os.path.exists(webui_dir):
self.data_path = os.path.abspath(webui_dir)
else:
self.data_path = os.path.abspath(
os.path.join(get_astrbot_data_path(), "dist"),
)
self.data_path = os.path.abspath(str(AstrbotPaths.astrbot_root / "dist"))
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
APP = self.app # noqa

astrbot_api/README.md Normal file

@@ -0,0 +1,25 @@
[project]
name = "astrbot-api"
dynamic = ["version"]
description = "Astrbot Python API"
readme = "README.md"
authors = [
{ name = "AstrBotDevs", email = "community@astrbot.app" }
]
requires-python = ">=3.10"
dependencies = [
"anyio>=4.11.0",
"mcp>=1.12.4",
"pydantic>=2.10.6",
"pyyaml>=6.0.3",
"tomli>=2.0.1",
"types-pyyaml>=6.0.12.20250915",
"yarl>=1.22.0",
]
[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
[tool.hatch.version]
source = "uv-dynamic-versioning"


@@ -0,0 +1,7 @@
from .abc import IAstrbotPaths
from .const import LOGO
__all__ = [
"IAstrbotPaths",
"LOGO",
]


@@ -0,0 +1,379 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Callable, Generator
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
asynccontextmanager,
contextmanager,
)
from importlib.abc import Traversable
from pathlib import Path
from typing import ClassVar
import anyio
import tomli
import yaml
from yarl import URL
from .models import AstrbotPluginMetadata
# region Core runtime protocols
class IAstrbotContainerMgr(ABC):
"""Abstract base class for the AstrBot container manager."""
class IAstrbotPaths(ABC):
"""Abstract base class for path management."""
astrbot_root: ClassVar[Path]
"""The Astrbot root directory."""
@abstractmethod
def __init__(self, name: str) -> None:
"""Initialize the path manager."""
@classmethod
@abstractmethod
def getPaths(cls, name: str) -> IAstrbotPaths:
"""Return a Paths instance for accessing the module's directories."""
@property
@abstractmethod
def root(self) -> Path:
"""The root directory."""
@property
@abstractmethod
def home(self) -> Path:
"""The module/plugin home directory."""
@property
@abstractmethod
def config(self) -> Path:
"""The module configuration directory."""
@property
@abstractmethod
def data(self) -> Path:
"""The module data directory."""
@property
@abstractmethod
def log(self) -> Path:
"""The module log directory."""
@property
@abstractmethod
def temp(self) -> Path:
"""The module temporary directory."""
@property
@abstractmethod
def plugins(self) -> Path:
"""The plugin directory."""
@abstractmethod
def reload(self) -> None:
"""Reload environment variables."""
@classmethod
@abstractmethod
def is_root(cls, path: Path) -> bool:
"""Check whether a path is the root directory."""
@abstractmethod
def chdir(self, cwd: str = "home") -> AbstractContextManager[Path]:
"""Temporarily switch to the given directory; child processes inherit this CWD."""
@abstractmethod
async def achdir(self, cwd: str = "home") -> AbstractAsyncContextManager[Path]:
"""Asynchronously switch to the given directory temporarily; child processes inherit this CWD."""
class AstrbotPluginBaseSession(ABC):
"""Base class for plugin sessions."""
url: URL
"""The plugin's astrbot-scheme URL.
Scheme: astrbot://{stdio/web/legacy}/plugin_id
"""
@abstractmethod
def connect(self) -> AstrbotPluginBaseSession:
"""Connect to the plugin."""
...
@abstractmethod
def disconnect(self) -> None:
"""Disconnect from the plugin."""
...
@abstractmethod
async def async_connect(self) -> AstrbotPluginBaseSession:
"""Asynchronously connect to the plugin."""
...
@abstractmethod
async def async_disconnect(self) -> None:
"""Asynchronously disconnect from the plugin."""
...
@contextmanager
def session_scope(self) -> Generator[AstrbotPluginBaseSession]:
"""Context manager for a plugin session."""
try:
yield self.connect()
finally:
self.disconnect()
@asynccontextmanager
async def async_session_scope(self) -> AsyncGenerator[AstrbotPluginBaseSession]:
"""Async context manager for a plugin session."""
try:
yield await self.async_connect()
finally:
await self.async_disconnect()
# region Data transfer methods
# Actively send data
@abstractmethod
def send(self, data: bytes) -> bytes:
"""Send data to the plugin and receive the response."""
...
@abstractmethod
async def async_send(self, data: bytes) -> bytes:
"""Asynchronously send data to the plugin and receive the response."""
...
# Passively receive data
@abstractmethod
def listen(self, callback: Callable[[bytes], bytes | None]) -> None:
"""Listen for data sent by the plugin.
callback: a function that takes bytes and returns bytes or None.
If it returns None, no response is sent back to the plugin.
"""
...
@abstractmethod
async def async_listen(self, callback: Callable[[bytes], bytes | None]) -> None:
"""Asynchronously listen for data sent by the plugin.
callback: a function that takes bytes and returns bytes or None.
If it returns None, no response is sent back to the plugin.
"""
...
class IVirtualAstrbotPlugin(ABC):
"""Base protocol for AstrBot virtual plugins."""
vpo_map: ClassVar[dict[URL, IVirtualAstrbotPlugin]] = {}
"""Virtual plugin object map, keyed by the plugin's astrbot-scheme URL."""
url: URL
"""The plugin's astrbot-scheme URL.
Scheme:
astrbot://{stdio/web/legacy}/plugin_id
"""
metadata: AstrbotPluginMetadata
"""Plugin metadata."""
session: AstrbotPluginBaseSession
"""Plugin session object."""
# region Public methods
@classmethod
@abstractmethod
def handshake(cls) -> None:
"""Generic plugin handshake.
1. Send the handshake request.
2. Receive the plugin metadata response and set the metadata attribute.
3. Return an acknowledgement indicating the handshake succeeded.
"""
...
# part 1: factory methods
@classmethod
@abstractmethod
def fromFile(cls, path: Path, *, stdio: bool = False) -> IVirtualAstrbotPlugin:
"""Load a plugin/plugin package from a file.
Use stdio=False for classic plugins and stdio=True for subprocess plugins.
Tasks:
1. Load the plugin from path, calling the private _load_metadata helper to load its metadata.
"""
...
@classmethod
@abstractmethod
def fromURL(cls, url: URL) -> IVirtualAstrbotPlugin:
"""Load a plugin/plugin package from a URL."""
...
# part 2: instance methods
@abstractmethod
def get_logo(self) -> Traversable | None:
"""Return the plugin's logo file path, if any.
Use importlib.resources.files to access it and return a Traversable.
Example:
from importlib.resources import files
logo_path = files("plugin_a") / "assets" / "logo.png"
This works directly when the plugin is installed into the virtual environment;
plugins that cannot be installed there (subprocess/web plugins) must resolve the path accordingly.
Return None when there is no logo file.
Returns:
Traversable | None: the path to the plugin's logo file, or None.
"""
...
@abstractmethod
def get_metadata(self) -> AstrbotPluginMetadata:
"""Return the plugin metadata."""
...
@abstractmethod
def get_url(self) -> URL:
"""Return the plugin URL (astrbot scheme)."""
...
@abstractmethod
def get_session(self) -> AstrbotPluginBaseSession:
"""Return the plugin session object."""
...
# region Magic methods
@abstractmethod
def __str__(self) -> str:
"""Return the string representation of the plugin metadata."""
...
@abstractmethod
def __repr__(self) -> str:
"""Return the formal string representation of the plugin metadata."""
...
# region Private methods
@classmethod
def _load_metadata(cls, path: Path) -> AstrbotPluginMetadata:
"""Load plugin metadata, trying formats in priority order: pyproject -> yaml -> toml -> json.
For yaml: plugin.yaml > metadata.yaml
This is a helper for the factory methods above.
"""
match path.suffix.lower():
case ".json":
return cls._load_metadata_json(path)
case ".toml":
return cls._load_metadata_toml(path)
case ".yaml" | ".yml":
return cls._load_metadata_yaml(path)
case _:
raise ValueError(f"Unsupported plugin metadata file format: {path.suffix}")
@classmethod
async def _load_metadata_async(cls, path: Path) -> AstrbotPluginMetadata:
"""Asynchronously load plugin metadata, trying formats in priority order: pyproject -> yaml -> toml -> json.
For yaml: plugin.yaml > metadata.yaml
This is a helper for the factory methods above.
"""
return await anyio.to_thread.run_sync(cls._load_metadata, path)  # type: ignore[attr-defined]
@classmethod
def _load_metadata_json(cls, path: Path) -> AstrbotPluginMetadata:
"""Load plugin metadata from a JSON file."""
with path.open(encoding="utf-8") as f:
data = json.load(f)
return AstrbotPluginMetadata.model_validate(data)
@classmethod
async def _load_metadata_json_async(cls, path: Path) -> AstrbotPluginMetadata:
"""Asynchronously load plugin metadata from a JSON file."""
async with await anyio.open_file(path, "r", encoding="utf-8") as f:
content = await f.read()
data = json.loads(content)
return AstrbotPluginMetadata.model_validate(data)
@classmethod
def _load_metadata_toml(cls, path: Path) -> AstrbotPluginMetadata:
"""Load plugin metadata from a TOML file."""
with path.open("rb") as f:
data = tomli.load(f)
return AstrbotPluginMetadata.model_validate(data)
@classmethod
async def _load_metadata_toml_async(cls, path: Path) -> AstrbotPluginMetadata:
"""Asynchronously load plugin metadata from a TOML file."""
async with await anyio.open_file(path, "rb") as f:
content = await f.read()
content_str = content.decode("utf-8")
data = tomli.loads(content_str)
return AstrbotPluginMetadata.model_validate(data)
@classmethod
def _load_metadata_yaml(cls, path: Path) -> AstrbotPluginMetadata:
"""Load plugin metadata from a YAML file."""
with path.open(encoding="utf-8") as f:
data = yaml.safe_load(f)
return AstrbotPluginMetadata.model_validate(data)
@classmethod
async def _load_metadata_yaml_async(cls, path: Path) -> AstrbotPluginMetadata:
"""Asynchronously load plugin metadata from a YAML file."""
async with await anyio.open_file(path, "r", encoding="utf-8") as f:
content = await f.read()
data = yaml.safe_load(content)
return AstrbotPluginMetadata.model_validate(data)
# region Plugin runtime protocol
class IAstrbotPluginRuntime(ABC):
"""Base protocol for the AstrBot plugin runtime."""
@abstractmethod
def start(self) -> None:
"""Start the plugin runtime."""
...
@abstractmethod
def stop(self) -> None:
"""Stop the plugin runtime."""
...


@@ -0,0 +1,9 @@
LOGO = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""


@@ -0,0 +1,11 @@
from enum import Enum
class PluginType(Enum):
"""Top-level plugin categories."""
# Kept for compatibility
LEGACY = "legacy"  # classic plugins, run in the main process -- being phased out
# The two below are the new plugin mechanisms
STDIO = "stdio"  # subprocess -- inter-process communication
WEB = "web"  # invoked over HTTP/HTTPS


@@ -0,0 +1,3 @@
class AstrbotBaseError(Exception):
"""Base exception for Astrbot API errors."""


@@ -0,0 +1,28 @@
from __future__ import annotations
from pydantic import BaseModel
class AstrbotPluginMetadata(BaseModel):
"""AstrBot plugin metadata model."""
name: str | None = None
"""Plugin name"""
author: str | None = None
"""Plugin author"""
desc: str | None = None
"""Plugin description"""
version: str | None = None
"""Plugin version"""
repo: str | None = None
"""Plugin repository URL"""
reserved: bool = False
"""Whether this is a reserved AstrBot plugin"""
activated: bool = True
"""Whether the plugin is activated"""
display_name: str | None = None
"""Display name of the plugin"""


@@ -0,0 +1 @@

astrbot_sdk/README.md Normal file

@@ -0,0 +1,25 @@
[project]
name = "astrbot-sdk"
dynamic = ["version"]
description = "Astrbot Python SDK"
readme = "README.md"
authors = [
{ name = "AstrBotDevs", email = "community@astrbot.app" }
]
requires-python = ">=3.10"
dependencies = [
"astrbot-api",
"dishka>=1.7.2",
"dotenv>=0.9.9",
"packaging>=24.0",
]
[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
[tool.hatch.version]
source = "uv-dynamic-versioning"
[tool.uv.sources]
astrbot-api = { workspace = true }


@@ -0,0 +1,23 @@
from dishka import (
AsyncContainer,
Container,
Provider,
make_async_container,
make_container,
)
from .base import AstrbotBaseProvider
base_provider: Provider = AstrbotBaseProvider()
async_base_container: AsyncContainer = make_async_container(base_provider)
sync_base_container: Container = make_container(base_provider)
__all__ = [
"async_base_container",
"sync_base_container",
]
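`sync_base_container.get(...)` is dishka's type-keyed lookup: you ask the container for a key and receive the object registered under it. The idea can be sketched without dishka as a minimal type-keyed registry (illustrative only, not dishka's actual implementation):

```python
class Container:
    """Minimal type-keyed registry illustrating the DI lookup pattern."""

    def __init__(self) -> None:
        self._registry: dict[object, object] = {}

    def register(self, key: object, value: object) -> None:
        self._registry[key] = value

    def get(self, key: object) -> object:
        return self._registry[key]


class IPaths: ...          # interface (hypothetical stand-in for IAstrbotPaths)
class Paths(IPaths): ...   # concrete implementation


container = Container()
# Register the concrete class under the interface key, mirroring how the SDK
# resolves it with sync_base_container.get(type[IAstrbotPaths]).
container.register(type[IPaths], Paths)

resolved = container.get(type[IPaths])
assert resolved is Paths
```

Keying on `type[IPaths]` (rather than `IPaths`) distinguishes "give me the class itself" from "give me an instance", which is the same distinction the SDK draws when it fetches the `AstrbotPaths` class from the container.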


@@ -0,0 +1,3 @@
# Base class implementations
- One-way exports


@@ -0,0 +1,3 @@
from .provider import AstrbotBaseProvider
__all__ = ["AstrbotBaseProvider"]


@@ -0,0 +1,131 @@
from __future__ import annotations
from collections.abc import AsyncGenerator, Generator
from contextlib import asynccontextmanager, contextmanager
from os import chdir, getenv
from pathlib import Path
from typing import ClassVar
from dotenv import load_dotenv
from packaging.utils import NormalizedName, canonicalize_name
from astrbot_api import IAstrbotPaths
class AstrbotPaths(IAstrbotPaths):
"""Unified path resolution."""
load_dotenv()
astrbot_root: ClassVar[Path] = Path(
getenv("ASTRBOT_ROOT", Path.home() / ".astrbot")
).absolute()
_instances: ClassVar[dict[str, AstrbotPaths]] = {}
def __init__(self, name: str) -> None:
self.name: str = name
# Ensure the root directory exists
self.astrbot_root.mkdir(parents=True, exist_ok=True)
@classmethod
def getPaths(cls, name: str) -> AstrbotPaths:
"""Return a Paths instance for accessing the module's directories."""
normalized_name: NormalizedName = canonicalize_name(name)
if normalized_name in cls._instances:
return cls._instances[normalized_name]
instance: AstrbotPaths = cls(normalized_name)
cls._instances[normalized_name] = instance
return instance
@property
def root(self) -> Path:
"""Return the root directory."""
return (
self.astrbot_root if self.astrbot_root.exists() else Path.cwd() / ".astrbot"
)
@property
def home(self) -> Path:
"""Return the module/plugin home directory."""
my_home = self.astrbot_root / "home" / self.name
my_home.mkdir(parents=True, exist_ok=True)
return my_home
@property
def config(self) -> Path:
"""Return the module/plugin configuration directory.
Used together with astrbot_config.
"""
config_path = self.astrbot_root / "config" / self.name
config_path.mkdir(parents=True, exist_ok=True)
return config_path
@property
def data(self) -> Path:
"""Return the module/plugin data directory."""
data_path = self.astrbot_root / "data" / self.name
data_path.mkdir(parents=True, exist_ok=True)
return data_path
@property
def log(self) -> Path:
"""Return the module log directory."""
log_path = self.astrbot_root / "logs" / self.name
log_path.mkdir(parents=True, exist_ok=True)
return log_path
@property
def temp(self) -> Path:
"""Return the module temporary directory."""
temp_path = self.astrbot_root / "temp" / self.name
temp_path.mkdir(parents=True, exist_ok=True)
return temp_path
@property
def plugins(self) -> Path:
"""Return the plugin directory."""
plugin_path = self.astrbot_root / "plugins" / self.name
plugin_path.mkdir(parents=True, exist_ok=True)
return plugin_path
@classmethod
def is_root(cls, path: Path) -> bool:
"""Check whether a path is the Astrbot root directory."""
if not path.exists() or not path.is_dir():
return False
# Check for the .astrbot marker file inside the directory
return (path / ".astrbot").exists()
def reload(self) -> None:
"""Reload environment variables."""
load_dotenv()
self.__class__.astrbot_root = Path(
getenv("ASTRBOT_ROOT", Path.home() / ".astrbot")
).absolute()
@contextmanager
def chdir(self, cwd: str = "home") -> Generator[Path]:
"""Temporarily switch to the given directory; child processes inherit this CWD."""
original_cwd = Path.cwd()
target_dir = self.root / cwd
try:
chdir(target_dir)
yield target_dir
finally:
chdir(original_cwd)
@asynccontextmanager
async def achdir(self, cwd: str = "home") -> AsyncGenerator[Path]:
"""Async context manager: temporarily switch to the given directory; child processes inherit this CWD."""
original_cwd = Path.cwd()
target_dir = self.root / cwd
try:
chdir(target_dir)
yield target_dir
finally:
chdir(original_cwd)
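The `chdir`/`achdir` context managers restore the original working directory even when the body raises. A self-contained sketch of the same pattern with only the stdlib (`pushd` is an illustrative name, not part of the AstrBot API):

```python
import os
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def pushd(target: Path):
    """Temporarily switch CWD; child processes spawned inside inherit it."""
    original = Path.cwd()
    try:
        os.chdir(target)
        yield target
    finally:
        os.chdir(original)  # always restored, even if the body raises


before = Path.cwd()
with tempfile.TemporaryDirectory() as tmp:
    with pushd(Path(tmp)):
        inside = Path.cwd()  # now inside the temporary directory

assert Path.cwd() == before  # CWD restored after the block
```

Because the process-wide CWD is mutated, this pattern is not safe under concurrency; two tasks entering `achdir` at once would race on the shared working directory, which is worth keeping in mind when the async variant is used.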

Some files were not shown because too many files have changed in this diff.