Compare commits

..

15 Commits

Author SHA1 Message Date
LIghtJUNction
c56d2aeced 文档更新 2025-11-03 23:35:22 +08:00
LIghtJUNction
fb56ac9f47 依赖注入示例 2025-11-03 23:28:13 +08:00
LIghtJUNction
5719dbd8b2 Update abc.py 2025-11-03 20:46:32 +08:00
LIghtJUNction
5217450c8a init 2025-11-03 20:28:55 +08:00
LIghtJUNction
15aafa2043 Update astrbot/core/config/default.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-03 14:26:00 +08:00
Copilot
36a3b4d318 chore: 运行 ruff format 修复代码格式 (#3291)
* Initial plan

* chore: 运行 ruff format 修复代码格式

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 14:25:01 +08:00
LIghtJUNction
7518c4a057 Update astrbot/core/utils/astrbot_path.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-03 12:36:47 +08:00
LIghtJUNction
8e3bd92c09 Update astrbot/core/utils/astrbot_path.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-03 03:14:37 +08:00
Copilot
9e4fec8488 [WIP] Update path class and deprecate old functions (#3287)
* Initial plan

* feat: 为 AstrbotPaths 添加全面测试,覆盖率达到 100%

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 03:06:53 +08:00
Copilot
983264dc1a fix: 移除未使用的临时目录变量 (#3286)
* Initial plan

* fix: 移除未使用的临时目录变量 (ruff check --fix)

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 02:51:12 +08:00
Copilot
b3f23b23da chore: 使用 ruff format 格式化代码 (#3283)
* Initial plan

* chore: 使用 ruff format 格式化代码

Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
2025-11-03 02:31:47 +08:00
LIghtJUNction
212d8cc101 版本号统一,常量移动至base包 (#3281)
Co-authored-by: 赵天乐(tyler zhao) <189870321+tyler-ztl@users.noreply.github.com>
2025-11-03 00:02:21 +08:00
LIghtJUNction
36de3c541a Update astrbot/base/paths.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-11-02 22:57:05 +08:00
LIghtJUNction
1151677f13 fix/ 抽象基类装饰器顺序写反了
Co-Authored-By: 赵天乐(tyler zhao) <189870321+tyler-ztl@users.noreply.github.com>
2025-11-02 22:54:43 +08:00
LIghtJUNction
2ccb85d802 统一路径类,旧函数已弃用
.env文件保存统一的路径
静态文件一起打包而不是从上上上级目录找...

Co-Authored-By: 赵天乐(tyler zhao) <189870321+tyler-ztl@users.noreply.github.com>
2025-11-02 22:47:22 +08:00
255 changed files with 9377 additions and 13797 deletions

View File

@@ -1,9 +1,9 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github actions
.git
# github acions
.github/
.*ignore
.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
@@ -15,10 +15,10 @@ env/
venv*/
ENV/
.conda/
README*.md
dashboard/
data/
changelogs/
tests/
.ruff_cache/
.astrbot
astrbot.lock
.astrbot

1
.env Normal file
View File

@@ -0,0 +1 @@
ASTRBOT_ROOT = ./data

2
.env.example Normal file
View File

@@ -0,0 +1,2 @@
# ASTRBOT 数据目录
# ASTRBOT_ROOT = ./data

View File

@@ -16,7 +16,7 @@ body:
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?可以从 [此](https://plugins.astrbot.app) 右下角提交。
不熟悉 JSON ?可以从 [此](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
- type: textarea
id: plugin-info

View File

@@ -1,44 +1,46 @@
name: '🐛 Report Bug / 报告 Bug'
name: '🐛 报告 Bug'
title: '[Bug]'
description: Submit bug report to help us improve. / 提交报告帮助我们改进。
description: 提交报告帮助我们改进。
labels: [ 'bug' ]
body:
- type: markdown
attributes:
value: |
Thank you for taking the time to report this issue! Please describe your problem accurately. If possible, please provide a reproducible snippet (this will help resolve the issue more quickly). Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
- type: textarea
attributes:
label: What happened / 发生了什么
description: Description
label: 发生了什么
description: 描述你遇到的异常
placeholder: >
Please provide a clear and specific description of what this exception is. Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
validations:
required: true
- type: textarea
attributes:
label: Reproduce / 如何复现?
label: 如何复现?
description: >
The steps to reproduce the issue. / 复现该问题的步骤
复现该问题的步骤
placeholder: >
Example: 1. Open '...'
: 1. 打开 '...'
validations:
required: true
- type: textarea
attributes:
label: AstrBot version, deployment method (e.g., Windows Docker Desktop deployment), provider used, and messaging platform used. / AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
label: AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
description: >
请提供您的 AstrBot 版本和部署方式。
placeholder: >
Example: 4.5.7 Docker, 3.1.7 Windows Launcher
如: 3.1.8 Docker, 3.1.7 Windows启动器
validations:
required: true
- type: dropdown
attributes:
label: OS
label: 操作系统
description: |
On which operating system did you encounter this problem? / 你在哪个操作系统上遇到了这个问题?
你在哪个操作系统上遇到了这个问题?
multiple: false
options:
- 'Windows'
@@ -51,30 +53,30 @@ body:
- type: textarea
attributes:
label: Logs / 报错日志
label: 报错日志
description: >
Please provide complete Debug-level logs, such as error logs and screenshots. Don't worry if they're long! Please note that issues with insufficient details or no logs will be closed immediately. Thank you for your understanding. / 如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
placeholder: >
Please provide a complete error log or screenshot. / 请提供完整的报错日志或截图。
请提供完整的报错日志或截图。
validations:
required: true
- type: checkboxes
attributes:
label: Are you willing to submit a PR? / 你愿意提交 PR 吗?
label: 你愿意提交 PR 吗?
description: >
This is not required, but we would be happy to provide guidance during the contribution process, especially if you already have a good understanding of how to implement the fix. / 这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
options:
- label: Yes!
- label: 是的,我愿意提交 PR!
- type: checkboxes
attributes:
label: Code of Conduct
options:
- label: >
I have read and agree to abide by the project's [Code of Conduct](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true
- type: markdown
attributes:
value: "Thank you for filling out our form! / 感谢您填写我们的表单!"
value: "感谢您填写我们的表单!"

View File

@@ -1,25 +1,44 @@
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX issue, adds YY feature)-->
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX issue添加了 YY 功能)-->
<!-- 如果有的话,请指定此 PR 旨在解决的 ISSUE 编号。 -->
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
fixes #XYZ
---
### Motivation / 动机
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
### Modifications / 改动点
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
- [x] This is NOT a breaking change. / 这不是一个破坏性变更。
<!-- If your changes is a breaking change, please uncheck the checkbox above -->
### Verification Steps / 验证步骤
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤例如1. 导航到... 2. 点击...)。-->
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
### Screenshots or Test Results / 运行截图或测试结果
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
### Compatibility & Breaking Changes / 兼容性与破坏性变更
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
---
### Checklist / 检查清单
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.

View File

@@ -1,63 +1,63 @@
# AstrBot Development Instructions
# AstrBot 开发指南
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
AstrBot 是一个使用 Python 编写、配备 Vue.js 仪表盘的多平台 LLM 聊天机器人开发框架。它支持多个消息平台QQ、TelegramDiscord 等)和多种 LLM 提供商(OpenAIAnthropicGoogle Gemini 等)。
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
始终优先参考这些指南,仅在遇到与此处信息不符的意外情况时才回退到搜索或 bash 命令。
## Working Effectively
## 高效工作
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### 引导和安装依赖
- **需要 Python 3.10+** - 检查 `.python-version` 文件
- 安装 UV 包管理器:`pip install uv`
- 安装项目依赖:`uv sync` -- 很快几分钟。绝不要取消。设置超时时间为 10+ 分钟。
- 创建必需的目录:`mkdir -p data/plugins data/config data/temp`
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### 运行应用程序
- 运行主应用程序:`uv run main.py` -- 约 3 秒启动
- 应用程序在 http://localhost:6185 创建 WebUI默认凭据`astrbot`/`astrbot`
- 应用程序自动从 `packages/` `data/plugins/` 目录加载插件
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### 仪表盘构建(Vue.js/Node.js
- **前置要求**:需要 Node.js 20+ npm 10+
- 导航到仪表盘:`cd dashboard`
- 安装仪表盘依赖:`npm install` -- 需要 2-3 分钟。绝不要取消。设置超时时间为 5+ 分钟。
- 构建仪表盘:`npm run build` -- 需要 25-30 秒。绝不要取消。
- 仪表盘在 `dashboard/dist/` 创建优化的生产构建
### Testing
- Do not generate test files for now.
### 测试
- 暂时不要生成测试文件。
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### 代码质量和检查
- 安装 ruff 检查器:`uv add --dev ruff`
- 检查代码风格:`uv run ruff check .` -- 耗时 <1
- 检查格式`uv run ruff format --check .` -- 耗时 <1
- 修复格式`uv run ruff format .`
- **始终**在提交更改前运行 `uv run ruff check .` `uv run ruff format .`
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### 插件开发
- 插件从 `packages/`内置 `data/plugins/`用户安装加载
- 插件系统支持函数工具和消息处理器
- 关键插件python_interpreterweb_searcherastrbotremindersession_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
### 常见问题和解决方法
- **仪表盘下载失败**已知的"除以零"错误问题 - 应用程序仍可正常工作
- **测试中的导入错误**确保使用 `uv run` 在适当的环境中运行测试
- **构建超时**始终设置适当的超时时间uv sync 10+ 分钟npm install 5+ 分钟
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## CI/CD 集成
- GitHub Actions 工作流在 `.github/workflows/`
- 通过 `Dockerfile` 支持 Docker 构建
- Pre-commit 钩子强制执行 ruff 格式化和检查
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Docker 支持
- 主要部署方法`docker run soulter/astrbot:latest`
- 可用的 Compose 文件`compose.yml`
- 暴露端口6185WebUI)、6195WeChat)、6199QQ
- 需要挂载卷`./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
## 多语言支持
- 文档包括中文README.md)、英文README_en.md)、日文README_ja.md
- UI 支持国际化
- 默认语言为中文
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
请记住这是一个有真实用户的生产聊天机器人框架始终进行彻底测试确保更改不会破坏现有功能

View File

@@ -13,7 +13,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v5
- name: Dashboard Build
run: |
@@ -70,7 +70,7 @@ jobs:
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6

View File

@@ -12,7 +12,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6

View File

@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v5
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
uses: actions/checkout@v5
with:
fetch-depth: 0

View File

@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v5
- name: Setup Node.js
uses: actions/setup-node@v6

View File

@@ -3,125 +3,18 @@ name: Docker Image CI/CD
on:
push:
tags:
- "v*"
schedule:
# Run at 00:00 UTC every day
- cron: "0 0 * * *"
- 'v*'
workflow_dispatch:
jobs:
build-nightly-image:
if: github.event_name == 'schedule'
publish-docker:
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: soulter
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Pull The Codes
uses: actions/checkout@v5
with:
fetch-depth: 1
fetch-tag: true
- name: Check for new commits today
if: github.event_name == 'schedule'
id: check-commits
run: |
# Get commits from the last 24 hours
commits=$(git log --since="24 hours ago" --oneline)
if [ -z "$commits" ]; then
echo "No commits in the last 24 hours, skipping build"
echo "has_commits=false" >> $GITHUB_OUTPUT
else
echo "Found commits in the last 24 hours:"
echo "$commits"
echo "has_commits=true" >> $GITHUB_OUTPUT
fi
- name: Exit if no commits
if: github.event_name == 'schedule' && steps.check-commits.outputs.has_commits == 'false'
run: exit 0
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
- name: Determine test image tags
id: test-meta
run: |
short_sha=$(echo "${GITHUB_SHA}" | cut -c1-12)
build_date=$(date +%Y%m%d)
echo "short_sha=$short_sha" >> $GITHUB_OUTPUT
echo "build_date=$build_date" >> $GITHUB_OUTPUT
- name: Set QEMU
uses: docker/setup-qemu-action@v3
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build nightly image tags list
id: test-tags
run: |
TAGS="${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-latest
${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
if [ "${{ env.HAS_GHCR_TOKEN }}" = "true" ]; then
TAGS="$TAGS
ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-latest
ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
fi
echo "tags<<EOF" >> $GITHUB_OUTPUT
echo "$TAGS" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Build and Push Nightly Image
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.test-tags.outputs.tags }}
- name: Post build notifications
run: echo "Test Docker image has been built and pushed successfully"
build-release-image:
if: github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v'))
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: soulter
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 1
fetch-tag: true
fetch-depth: 0 # Must be 0 so we can fetch tags
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
@@ -134,22 +27,21 @@ jobs:
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Compute release metadata
id: release-meta
- name: Check if version is pre-release
id: check-prerelease
run: |
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${GITHUB_REF#refs/tags/}"
version="${{ github.ref_name }}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version marked as pre-release"
echo "Version $version is a pre-release, will not push latest tag"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version marked as stable"
echo "Version $version is a stable release, will push latest tag"
fi
echo "version=$version" >> $GITHUB_OUTPUT
- name: Build Dashboard
run: |
@@ -175,24 +67,23 @@ jobs:
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
username: Soulter
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Release Image
- name: Build and Push Docker to DockerHub and Github GHCR
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ steps.release-meta.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', env.DOCKER_HUB_USERNAME) || '' }}
${{ steps.release-meta.outputs.is_prerelease == 'false' && env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:latest', env.GHCR_OWNER) || '' }}
${{ format('{0}/astrbot:{1}', env.DOCKER_HUB_USERNAME, steps.release-meta.outputs.version) }}
${{ env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:{1}', env.GHCR_OWNER, steps.release-meta.outputs.version) || '' }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications
run: echo "Release Docker image has been built and pushed successfully"
run: echo "Docker image has been built and pushed successfully"

1
.gitignore vendored
View File

@@ -34,7 +34,6 @@ dashboard/node_modules/
dashboard/dist/
package-lock.json
package.json
yarn.lock
# Operating System
**/.DS_Store

View File

@@ -12,21 +12,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
bash \
ffmpeg \
curl \
gnupg \
git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y nodejs
RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv \
&& echo "3.11" > .python-version
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]
CMD [ "python", "main.py" ]

35
Dockerfile_with_node Normal file
View File

@@ -0,0 +1,35 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]

120
README.md
View File

@@ -8,7 +8,7 @@
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
@@ -32,7 +32,7 @@
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手还是企业知识库AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
## 主要功能
@@ -42,7 +42,7 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
5. **WebUI**。可视化配置和管理机器人,功能齐全。
## 部署方式
## 部署方式
#### Docker 部署(推荐 🥳)
@@ -119,73 +119,83 @@ uv run main.py
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## 支持的消息平台
## 消息平台支持情况
**官方维护**
- QQ (官方平台 & OneBot)
- Telegram
- 企微应用 & 企微智能机器人
- 微信客服 & 微信公众号
- 飞书
- 钉钉
- Slack
- Discord
- Satori
- Misskey
- Whatsapp (将支持)
- LINE (将支持)
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方平台) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
## 支持的模型服务
## ⚡ 提供商支持情况
**大模型服务**
- OpenAI 及兼容服务
- Anthropic
- Google Gemini
- Moonshot AI
- 智谱 AI
- DeepSeek
- Ollama (本地部署)
- LM Studio (本地部署)
- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小马算力](https://www.tokenpony.cn/3YPyf)
- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**LLMOps 平台**
- Dify
- 阿里云百炼应用
- Coze
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
- OpenAI Whisper
- SenseVoice
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- 阿里云百炼 TTS
- Azure TTS
- Minimax TTS
- 火山引擎 TTS
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献
@@ -219,7 +229,7 @@ pre-commit install
## ⭐ Star History
> [!TIP]
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我们维护这个开源项目的动力 <3
<div align="center">

View File

@@ -1,233 +1,182 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
<br>
_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot)
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">Documentation</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
</div>
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
## Key Features
## Key Features
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
## Deployment Methods
> [!TIP]
> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
#### Docker Deployment (Recommended 🥳)
## Deployment
We recommend deploying AstrBot using Docker or Docker Compose.
#### Docker Deployment
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
#### BT-Panel Deployment
#### Windows Installer
AstrBot has partnered with BT-Panel and is now available in their marketplace.
Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
Please refer to the official documentation: [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html).
#### 1Panel Deployment
AstrBot has been officially listed on the 1Panel marketplace.
Please refer to the official documentation: [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html).
#### Deploy on RainYun
AstrBot has been officially listed on RainYun's cloud application platform with one-click deployment.
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### Deploy on Replit
Community-contributed deployment method.
#### Replit Deployment
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### Windows One-Click Installer
Please refer to the official documentation: [Deploy AstrBot with Windows One-Click Installer](https://astrbot.app/deploy/astrbot/windows.html).
#### CasaOS Deployment
Community-contributed deployment method.
Please refer to the official documentation: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html).
Community-contributed method.
See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
#### Manual Deployment
First, install uv:
See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
```bash
pip install uv
```
## ⚡ Platform Support
Install AstrBot via Git Clone:
| Platform | Status | Details | Message Types |
| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - |
| WhatsApp | 🚧 | Planned | - |
| Xiaomi Speakers | 🚧 | Planned | - |
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
## Provider Support Status
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
| Name | Support | Type | Notes |
|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
| Claude API | ✔ | Text Generation | |
| Google Gemini API | ✔ | Text Generation | |
| Dify | ✔ | LLMOps | |
| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
| OneAPI | ✔ | LLM Distribution | |
| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
| SenseVoice | ✔ | Speech-to-Text | Local deployment |
| OpenAI TTS API | ✔ | Text-to-Speech | |
| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
## 🌍 Community
### QQ Groups
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Developer Group: 975206796
### Telegram Group
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord Server
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## Supported Messaging Platforms
**Officially Maintained**
- QQ (Official Platform & OneBot)
- Telegram
- WeChat Work Application & WeChat Work Intelligent Bot
- WeChat Customer Service & WeChat Official Accounts
- Feishu (Lark)
- DingTalk
- Slack
- Discord
- Satori
- Misskey
- WhatsApp (Coming Soon)
- LINE (Coming Soon)
**Community Maintained**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
## Supported Model Services
**LLM Services**
- OpenAI and Compatible Services
- Anthropic
- Google Gemini
- Moonshot AI
- Zhipu AI
- DeepSeek
- Ollama (Self-hosted)
- LM Studio (Self-hosted)
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**LLMOps Platforms**
- Dify
- Alibaba Cloud Bailian Applications
- Coze
**Speech-to-Text Services**
- OpenAI Whisper
- SenseVoice
**Text-to-Speech Services**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- Alibaba Cloud Bailian TTS
- Azure TTS
- Minimax TTS
- Volcano Engine TTS
## ❤️ Contributing
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
### How to Contribute
You can contribute by reviewing issues or helping with pull request reviews. Any issues or PRs are welcome to encourage community participation. Of course, these are just suggestions—you can contribute in any way you like. For adding new features, please discuss through an Issue first.
### Development Environment
AstrBot uses `ruff` for code formatting and linting.
```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## ❤️ Special Thanks
Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
Additionally, the birth of this project would not have been possible without the help of the following open-source projects:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - The amazing cat framework
## ⭐ Star History
# 🦌 Roadmap
> [!TIP]
> If this project has helped you in your life or work, or if you're interested in its future development, please give the project a Star. It's the driving force behind maintaining this open-source project <3
> Suggestions welcome via Issues <3
<div align="center">
- [ ] Ensure feature parity across all platform adapters
- [ ] Optimize plugin APIs
- [ ] Add default TTS services (e.g., GPT-Sovits)
- [ ] Enhance chat features with persistent memory
- [ ] i18n Planning
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
## ❤️ Contributions
All Issues/PRs welcome! Simply submit your changes to this project :)
For major features, please discuss via Issues first.
## 🌟 Support
- Star this project!
- Support via [Afdian](https://afdian.com/a/soulter)
- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
## ✨ Demos
> [!NOTE]
> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ Multimodal Input, Web Search, Text-to-Image ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ Natural Language TODO Lists ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ Plugin System Showcase ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ Web Dashboard ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ Built-in Web Chat Interface ✨_
</div>
</details>
## ⭐ Star History
> [!TIP]
> If this project helps you, please give it a star <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=AstrBotDevs/AstrBot&type=Date)](https://star-history.com/#AstrBotDevs/AstrBot&Date)
</div>
## Disclaimer
1. Licensed under `AGPL-v3`.
2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
3. Users must comply with local laws and regulations.
<!-- ## ✨ ATRI [Beta]
Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory
3. Meme understanding & responses
4. TTS integration
-->
_私は、高性能ですから!_

View File

@@ -1,233 +1,167 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
<br>
_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot)
<a href="https://astrbot.app/">ドキュメントを見る</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
</div>
<br>
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデルLLM接続機能を備えたチャットボットおよび開発フレームワークです。
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
</div>
## ✨ 主な機能
<br>
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://astrbot.app/">ドキュメント</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
</div>
> [!TIP]
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。
## ✨ 使用方法
## 主な機能
#### Docker デプロイ
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。
公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください
## デプロイ方法
#### Windows ワンクリックインストーラーのデプロイ
#### Docker デプロイ(推奨 🥳)
コンピュータに Python>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
Docker / Docker Compose を使用した AstrBot デプロイを推奨します。
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
#### 宝塔パネルデプロイ
AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
公式ドキュメント [宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html) をご参照ください。
#### 1Panel デプロイ
AstrBot は 1Panel 公式により 1Panel パネルに公開されています。
公式ドキュメント [1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html) をご参照ください。
#### 雨云でのデプロイ
AstrBot は雨云公式によりクラウドアプリケーションプラットフォームに公開され、ワンクリックでデプロイ可能です。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### Replit でのデプロイ
コミュニティ貢献によるデプロイ方法。
#### Replit デプロイ
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### Windows ワンクリックインストーラーデプロイ
公式ドキュメント [Windows ワンクリックインストーラーを使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/windows.html) をご参照ください。
#### CasaOS デプロイ
コミュニティ貢献によるデプロイ方法。
コミュニティが提供するデプロイ方法です
公式ドキュメント [CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html) を参照ください。
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
#### 手動デプロイ
まず uv をインストールします:
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
```bash
pip install uv
```
## ⚡ メッセージプラットフォームのサポート状況
Git Clone で AstrBot をインストール:
| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
| -------- | ------- | ------- | ------ |
| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
| Feishu | ✔ | グループチャット | テキスト、画像 |
| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
| Discord | 🚧 | 計画中 | - |
| WhatsApp | 🚧 | 計画中 | - |
| Xiaoai 音響 | 🚧 | 計画中 | - |
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
# 🦌 今後のロードマップ
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
> [!TIP]
> Issue でさらに多くの提案を歓迎します <3
## 🌍 コミュニティ
- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
- [ ] プラグインインターフェースの最適化
- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
- [ ] i18n の計画
### QQ グループ
## ❤️ 貢献
- 1群:322154837
- 3群:630166526
- 5群:822130018
- 6群:753075035
- 開発者群:975206796
Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
### Telegram グループ
新機能の追加については、まず Issue で議論してください。
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## 🌟 サポート
### Discord サーバー
- このプロジェクトに Star を付けてください!
- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ✨ デモ
## サポートされているメッセージプラットフォーム
> [!NOTE]
> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
**公式メンテナンス**
<div align='center'>
- QQ (公式プラットフォーム & OneBot)
- Telegram
- WeChat Work アプリケーション & WeChat Work インテリジェントボット
- WeChat カスタマーサービス & WeChat 公式アカウント
- Feishu (Lark)
- DingTalk
- Slack
- Discord
- Satori
- Misskey
- WhatsApp (近日対応予定)
- LINE (近日対応予定)
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
**コミュニティメンテナンス**
_✨ Docker ベースのサンドボックス化されたコードエグゼキューターベータテスト中✨_
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
## サポートされているモデルサービス
_✨ 多モーダル、ウェブ検索、長文の画像変換設定可能✨_
**大規模言語モデルサービス**
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
- OpenAI および互換サービス
- Anthropic
- Google Gemini
- Moonshot AI
- 智谱 AI
- DeepSeek
- Ollama (セルフホスト)
- LM Studio (セルフホスト)
- [優云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小馬算力](https://www.tokenpony.cn/3YPyf)
- [硅基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
_✨ 自然言語タスク ✨_
**LLMOps プラットフォーム**
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
- Dify
- Alibaba Cloud 百炼アプリケーション
- Coze
_✨ プラグインシステム - 一部のプラグインの展示 ✨_
**音声認識サービス**
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width="600">
- OpenAI Whisper
- SenseVoice
_✨ 管理パネル ✨_
**音声合成サービス**
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- Alibaba Cloud 百炼 TTS
- Azure TTS
- Minimax TTS
- Volcano Engine TTS
_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## ❤️ コントリビューション
Issue や Pull Request は大歓迎です!このプロジェクトに変更を送信してください :)
### コントリビュート方法
Issue を確認したり、PR(プルリクエスト)のレビューを手伝うことで貢献できます。どんな Issue や PR への参加も歓迎され、コミュニティ貢献を促進します。もちろん、これらは提案に過ぎず、どんな方法でも貢献できます。新機能の追加については、まず Issue で議論してください。
### 開発環境
AstrBot はコードのフォーマットとチェックに `ruff` を使用しています。
```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## ❤️ Special Thanks
AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
また、このプロジェクトの誕生は以下のオープンソースプロジェクトの助けなしには実現できませんでした:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 素晴らしい猫猫フレームワーク
</div>
## ⭐ Star History
> [!TIP]
> このプロジェクトがあなたの生活や仕事に役立ったり、このプロジェクトの今後の発展に関心がある場合は、プロジェクトに Star をください。これこのオープンソースプロジェクトを維持する原動力です <3
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これこのオープンソースプロジェクトを維持するためのモチベーションです <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
</details>
## スポンサー
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
<!-- ## ✨ ATRI [ベータテスト]
この機能はプラグインとしてロードされます。プラグインリポジトリのアドレス:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 《ATRI ~ My Dear Moments》の主人公 ATRI のキャラクターセリフを微調整データセットとして使用した `Qwen1.5-7B-Chat Lora` 微調整モデル。
2. 長期記憶
3. ミームの理解と返信
4. TTS
-->
_私は、高性能ですから!_

97
astrbot/__main__.py Normal file
View File

@@ -0,0 +1,97 @@
import argparse
import asyncio
import mimetypes
import os
import sys
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.config.default import VERSION
from astrbot.core.initial_loader import InitialLoader
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot_api import LOGO, IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
def check_env() -> None:
if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
logger.error("请使用 Python3.10+ 运行本项目。")
exit()
# os.makedirs("data/config", exist_ok=True)
# os.makedirs("data/plugins", exist_ok=True)
# os.makedirs("data/temp", exist_ok=True)
# 针对问题 #181 的临时解决方案
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("application/json", ".json")
async def check_dashboard_files(webui_dir: str | None = None):
"""下载管理面板文件"""
# 指定webui目录
if webui_dir:
if os.path.exists(webui_dir):
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
return webui_dir
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
data_dist_path = str(AstrbotPaths.astrbot_root / "dist")
if os.path.exists(data_dist_path):
v = await get_dashboard_version()
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI 版本已是最新。")
else:
logger.warning(
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
)
return data_dist_path
logger.info(
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip并将其中的 dist 文件夹解压至 data 目录下。",
)
try:
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return data_dist_path
def main():
parser = argparse.ArgumentParser(description="AstrBot")
parser.add_argument(
"--webui-dir",
type=str,
help="指定 WebUI 静态文件目录路径",
default=None,
)
args = parser.parse_args()
check_env()
# 启动日志代理
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# 检查仪表板文件
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
db = db_helper
# 打印 logo
logger.info(LOGO)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
asyncio.run(core_lifecycle.start())
if __name__ == "__main__":
main()

View File

@@ -36,8 +36,7 @@ from astrbot.core.star.config import *
# provider
from astrbot.core.provider import Provider, ProviderMetaData
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, Personality, ProviderMetaData
# platform
from astrbot.core.platform import (

View File

@@ -1,5 +1,4 @@
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, STTProvider
from astrbot.core.provider import Personality, Provider, STTProvider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,

View File

@@ -1 +1 @@
__version__ = "4.7.3"
"""AstrBot CLI入口"""

View File

@@ -1,27 +1,22 @@
"""AstrBot CLI入口"""
import sys
from importlib.metadata import version
import click
from . import __version__
from astrbot_api import LOGO
from .commands import conf, init, plug, run
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
__version__ = version("astrbot")
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo(LOGO)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")

View File

@@ -3,6 +3,11 @@ import shutil
from pathlib import Path
import click
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from ..utils import (
PluginStatus,
@@ -47,8 +52,7 @@ def display_plugins(plugins, title=None, color=None):
@click.argument("name")
def new(name: str):
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
plug_path = AstrbotPaths.getPaths(name).plugins
if plug_path.exists():
raise click.ClickException(f"插件 {name} 已存在")

View File

@@ -1,22 +1,31 @@
import warnings
from pathlib import Path
import click
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths = sync_base_container.get(type[IAstrbotPaths])
def check_astrbot_root(path: str | Path) -> bool:
"""检查路径是否为 AstrBot 根目录"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
return False
if not (path / ".astrbot").exists():
return False
return True
warnings.warn(
"请使用 AstrbotPaths 类代替本模块中的函数",
DeprecationWarning,
stacklevel=2,
)
return AstrbotPaths.is_root(Path(path))
def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径"""
return Path.cwd()
warnings.warn(
"请使用 AstrbotPaths 类代替本模块中的函数",
DeprecationWarning,
stacklevel=2,
)
return AstrbotPaths.astrbot_root
async def check_dashboard(astrbot_root: Path) -> None:

View File

@@ -9,10 +9,6 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from .log import LogBroker, LogManager # noqa
from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", False)

View File

@@ -4,14 +4,6 @@ from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
@@ -20,24 +12,21 @@ from .run_context import TContext
from .tool import FunctionTool
try:
import anyio
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
)
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
)
def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format"""
"""准备配置,处理嵌套格式"""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
@@ -46,7 +35,7 @@ def _prepare_config(config: dict) -> dict:
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity"""
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
@@ -61,7 +50,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP connection config missing transport or type field")
raise Exception("MCP 连接配置缺少 transport type 字段")
async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http":
@@ -102,7 +91,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
@@ -112,7 +101,6 @@ class MCPClient:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
self.name: str | None = None
self.active: bool = True
@@ -120,32 +108,22 @@ class MCPClient:
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
# Store connection config for reconnection
self._mcp_server_config: dict | None = None
self._server_name: str | None = None
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""Connect to MCP server
"""连接到 MCP 服务器
If `url` parameter exists:
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
2. When transport is specified as `sse`, use SSE connection.
3. If not specified, default to SSE connection to MCP service.
如果 `url` 参数存在:
1. transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
# Store config for reconnection
self._mcp_server_config = mcp_server_config
self._server_name = name
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# Handle MCP service error logs
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
@@ -159,7 +137,7 @@ class MCPClient:
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP connection config missing transport or type field")
raise Exception("MCP 连接配置缺少 transport type 字段")
if transport_type != "streamable_http":
# SSE transport method
@@ -215,7 +193,7 @@ class MCPClient:
)
def callback(msg: str):
# Handle MCP service error logs
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
@@ -244,120 +222,10 @@ class MCPClient:
self.tools = response.tools
return response
async def _reconnect(self) -> None:
"""Reconnect to the MCP server using the stored configuration.
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
Raises:
Exception: raised when reconnection fails
"""
async with self._reconnect_lock:
# Check if already reconnecting (useful for logging)
if self._reconnecting:
logger.debug(
f"MCP Client {self._server_name} is already reconnecting, skipping"
)
return
if not self._mcp_server_config or not self._server_name:
raise Exception("Cannot reconnect: missing connection configuration")
self._reconnecting = True
try:
logger.info(
f"Attempting to reconnect to MCP server {self._server_name}..."
)
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
if self.exit_stack:
self._old_exit_stacks.append(self.exit_stack)
# Mark old session as invalid
self.session = None
# Create new exit stack for new connection
self.exit_stack = AsyncExitStack()
# Reconnect using stored config
await self.connect_to_server(self._mcp_server_config, self._server_name)
await self.list_tools_and_save()
logger.info(
f"Successfully reconnected to MCP server {self._server_name}"
)
except Exception as e:
logger.error(
f"Failed to reconnect to MCP server {self._server_name}: {e}"
)
raise
finally:
self._reconnecting = False
async def call_tool_with_reconnect(
self,
tool_name: str,
arguments: dict,
read_timeout_seconds: timedelta,
) -> mcp.types.CallToolResult:
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
Args:
tool_name: tool name
arguments: tool arguments
read_timeout_seconds: read timeout
Returns:
MCP tool call result
Raises:
ValueError: MCP session is not available
anyio.ClosedResourceError: raised after reconnection failure
"""
@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def _call_with_retry():
if not self.session:
raise ValueError("MCP session is not available for MCP function tools.")
try:
return await self.session.call_tool(
name=tool_name,
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
logger.warning(
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
)
# Attempt to reconnect
await self._reconnect()
# Reraise the exception to trigger tenacity retry
raise
return await _call_with_retry()
async def cleanup(self):
"""Clean up resources including old exit stacks from reconnections"""
# Close current exit stack
try:
await self.exit_stack.aclose()
except Exception as e:
logger.debug(f"Error closing current exit stack: {e}")
# Don't close old exit stacks as they may be in different task contexts
# They will be garbage collected naturally
# Just clear the list to release references
self._old_exit_stacks.clear()
# Set running_event first to unblock any waiting tasks
self.running_event.set()
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
class MCPTool(FunctionTool, Generic[TContext]):
@@ -378,8 +246,14 @@ class MCPTool(FunctionTool, Generic[TContext]):
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> mcp.types.CallToolResult:
return await self.mcp_client.call_tool_with_reconnect(
tool_name=self.mcp_tool.name,
session = self.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=self.mcp_tool.name,
arguments=kwargs,
read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
read_timeout_seconds=timedelta(
seconds=context.tool_call_timeout,
),
)
return res

View File

@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema
@@ -76,7 +76,7 @@ class ImageURLPart(ContentPart):
"""The ID of the image, to allow LLMs to distinguish different images."""
type: str = "image_url"
image_url: ImageURL
image_url: str
class AudioURLPart(ContentPart):
@@ -119,13 +119,6 @@ class ToolCall(BaseModel):
"""The ID of the tool call."""
function: FunctionBody
"""The function body of the tool call."""
extra_content: dict[str, Any] | None = None
"""Extra metadata for the tool call."""
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
if self.extra_content is None:
kwargs.setdefault("exclude", set()).add("extra_content")
return super().model_dump(**kwargs)
class ToolCallPart(BaseModel):
@@ -145,39 +138,22 @@ class Message(BaseModel):
"tool",
]
content: str | list[ContentPart] | None = None
content: str | list[ContentPart]
"""The content of the message."""
tool_calls: list[ToolCall] | list[dict] | None = None
"""The tool calls of the message."""
tool_call_id: str | None = None
"""The ID of the tool call."""
@model_validator(mode="after")
def check_content_required(self):
# assistant + tool_calls is not None: allow content to be None
if self.role == "assistant" and self.tool_calls is not None:
return self
# other all cases: content is required
if self.content is None:
raise ValueError(
"content is required unless role='assistant' and tool_calls is not None"
)
return self
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message):
"""A message segment representing a tool call."""
role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message):

View File

@@ -1,21 +1,16 @@
from dataclasses import dataclass
from typing import Any, Generic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypeVar
from .message import Message
TContext = TypeVar("TContext", default=Any)
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds

View File

@@ -2,12 +2,13 @@ import abc
import typing as T
from enum import Enum, auto
from astrbot import logger
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
class AgentState(Enum):
@@ -23,7 +24,9 @@ class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
@@ -37,13 +40,6 @@ class BaseAgentRunner(T.Generic[TContext]):
"""Process a single step of the agent."""
...
@abc.abstractmethod
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
...
@abc.abstractmethod
def done(self) -> bool:
"""Check if the agent has completed its task.
@@ -57,9 +53,3 @@ class BaseAgentRunner(T.Generic[TContext]):
This method should be called after the agent is done.
"""
...
def _transition_state(self, new_state: AgentState) -> None:
"""Transition the agent state."""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state

View File

@@ -1,367 +0,0 @@
import base64
import json
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .coze_api_client import CozeAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class CozeAgentRunner(BaseAgentRunner[TContext]):
"""Coze Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
# 会话相关缓存
self.file_id_cache: dict[str, dict[str, str]] = {}
@override
async def step(self):
"""
执行 Coze Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
# 执行 Coze 请求并处理结果
async for response in self._execute_coze_request():
yield response
except Exception as e:
logger.error(f"Coze 请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Coze 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _execute_coze_request(self):
"""执行 Coze 请求的核心逻辑"""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
contexts = self.req.contexts or []
system_prompt = self.req.system_prompt
# 用户ID参数
user_id = session_id
# 获取或创建会话ID
conversation_id = await sp.get_async(
scope="umo",
scope_id=user_id,
key="coze_conversation_id",
default="",
)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{
"role": "system",
"content": system_prompt,
"content_type": "text",
},
)
# 处理历史上下文
if not self.auto_save_history and contexts:
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
# 处理上下文中的图片
content = ctx["content"]
if isinstance(content, list):
# 多模态内容,需要处理图片
processed_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片上传
try:
image_data = item.get("image_url", {})
url = image_data.get("url", "")
if url:
file_id = (
await self._download_and_upload_image(
url, session_id
)
)
processed_content.append(
{
"type": "file",
"file_id": file_id,
"file_url": url,
}
)
except Exception as e:
logger.warning(f"处理上下文图片失败: {e}")
continue
if processed_content:
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本内容
additional_messages.append(
{
"role": ctx["role"],
"content": content,
"content_type": "text",
}
)
# 构建当前消息
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
# the url is a base64 string
try:
image_data = base64.b64decode(url)
file_id = await self.api_client.upload_file(image_data)
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.warning(f"处理图片失败 {url}: {e}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
elif prompt:
# 纯文本
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
},
)
# 执行 Coze API 请求
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
await sp.put_async(
scope="umo",
scope_id=user_id,
key="coze_conversation_id",
value=data["conversation_id"],
)
if event_type == "conversation.message.delta":
# 增量消息
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
accumulated_content += content
message_started = True
# 如果是流式响应,发送增量数据
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(content)
),
)
elif event_type == "conversation.message.completed":
# 消息完成
logger.debug("Coze message completed")
message_started = True
elif event_type == "conversation.chat.completed":
# 对话完成
logger.debug("Coze chat completed")
break
elif event_type == "error":
# 错误处理
error_msg = data.get("msg", "未知错误")
error_code = data.get("code", "UNKNOWN")
logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")
if not message_started and not accumulated_content:
logger.warning("Coze 未返回任何内容")
accumulated_content = ""
# 创建最终响应
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回最终结果
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def _download_and_upload_image(
self,
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
import hashlib
# 计算哈希实现缓存
cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
if session_id:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self.api_client.upload_file(image_data)
if session_id:
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
raise Exception(f"处理图片失败: {e!s}")
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -1,403 +0,0 @@
import asyncio
import functools
import queue
import re
import sys
import threading
import typing as T
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
"""Dashscope Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.variables: dict = provider_config.get("variables", {}) or {}
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""判断是否有 RAG 选项
Returns:
bool: 是否有 RAG 选项
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
@override
async def step(self):
"""
执行 Dashscope Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
# 执行 Dashscope 请求并处理结果
async for response in self._execute_dashscope_request():
yield response
except Exception as e:
logger.error(f"阿里云百炼请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
),
)
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
def _consume_sync_generator(
self, response: T.Any, response_queue: queue.Queue
) -> None:
"""在线程中消费同步generator,将结果放入队列
Args:
response: 同步generator对象
response_queue: 用于传递数据的队列
"""
try:
if self.streaming:
for chunk in response:
response_queue.put(("data", chunk))
else:
response_queue.put(("data", response))
except Exception as e:
response_queue.put(("error", e))
finally:
response_queue.put(("done", None))
async def _process_stream_chunk(
self, chunk: ApplicationResponse, output_text: str
) -> tuple[str, list | None, AgentResponse | None]:
"""处理流式响应的单个chunk
Args:
chunk: Dashscope响应chunk
output_text: 当前累积的输出文本
Returns:
(更新后的output_text, doc_references, AgentResponse或None)
"""
logger.debug(f"dashscope stream chunk: {chunk}")
if chunk.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
)
self._transition_state(AgentState.ERROR)
error_msg = (
f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
)
self.final_llm_resp = LLMResponse(
role="err",
result_chain=MessageChain().message(error_msg),
)
return (
output_text,
None,
AgentResponse(
type="err",
data=AgentResponseData(chain=MessageChain().message(error_msg)),
),
)
chunk_text = chunk.output.get("text", "") or ""
# RAG 引用脚标格式化
chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
response = None
if chunk_text:
output_text += chunk_text
response = AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(chunk_text)),
)
# 获取文档引用
doc_references = chunk.output.get("doc_references", None)
return output_text, doc_references, response
def _format_doc_references(self, doc_references: list) -> str:
"""格式化文档引用为文本
Args:
doc_references: 文档引用列表
Returns:
格式化后的引用文本
"""
ref_parts = []
for ref in doc_references:
ref_title = (
ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
)
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
ref_str = "".join(ref_parts)
return f"\n\n回答来源:\n{ref_str}"
async def _build_request_payload(
self, prompt: str, session_id: str, contexts: list, system_prompt: str
) -> dict:
"""构建请求payload
Args:
prompt: 用户输入
session_id: 会话ID
contexts: 上下文列表
system_prompt: 系统提示词
Returns:
请求payload字典
"""
conversation_id = await sp.get_async(
scope="umo",
scope_id=session_id,
key="dashscope_conversation_id",
default="",
)
# 获得会话变量
payload_vars = self.variables.copy()
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# 支持多轮对话的
p = {
"app_id": self.app_id,
"api_key": self.api_key,
"prompt": prompt,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if conversation_id:
p["session_id"] = conversation_id
return p
else:
# 不支持多轮对话的
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
return payload
async def _handle_streaming_response(
self, response: T.Any, session_id: str
) -> T.AsyncGenerator[AgentResponse, None]:
"""处理流式响应
Args:
response: Dashscope 流式响应 generator
Yields:
AgentResponse 对象
"""
response_queue = queue.Queue()
consumer_thread = threading.Thread(
target=self._consume_sync_generator,
args=(response, response_queue),
daemon=True,
)
consumer_thread.start()
output_text = ""
doc_references = None
while True:
try:
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
continue
if item_type == "done":
break
elif item_type == "error":
raise item_data
elif item_type == "data":
chunk = item_data
assert isinstance(chunk, ApplicationResponse)
(
output_text,
chunk_doc_refs,
response,
) = await self._process_stream_chunk(chunk, output_text)
if response:
if response.type == "err":
yield response
return
yield response
if chunk_doc_refs:
doc_references = chunk_doc_refs
if chunk.output.session_id:
await sp.put_async(
scope="umo",
scope_id=session_id,
key="dashscope_conversation_id",
value=chunk.output.session_id,
)
# 添加 RAG 引用
if self.output_reference and doc_references:
ref_text = self._format_doc_references(doc_references)
output_text += ref_text
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(ref_text)),
)
# 创建最终响应
chain = MessageChain(chain=[Comp.Plain(output_text)])
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回最终结果
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def _execute_dashscope_request(self):
"""执行 Dashscope 请求的核心逻辑"""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
contexts = self.req.contexts or []
system_prompt = self.req.system_prompt
# 检查图片输入
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
# 构建请求payload
payload = await self._build_request_payload(
prompt, session_id, contexts, system_prompt
)
if not self.streaming:
payload["incremental_output"] = False
# 发起请求
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -1,336 +0,0 @@
import base64
import os
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .dify_api_client import DifyAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DifyAgentRunner(BaseAgentRunner[TContext]):
"""Dify Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("dify_api_key", "")
self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_type = provider_config.get("dify_api_type", "chat")
self.workflow_output_key = provider_config.get(
"dify_workflow_output_key",
"astrbot_wf_output",
)
self.dify_query_input_key = provider_config.get(
"dify_query_input_key",
"astrbot_text_query",
)
self.variables: dict = provider_config.get("variables", {}) or {}
self.timeout = provider_config.get("timeout", 60)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.api_client = DifyAPIClient(self.api_key, self.api_base)
@override
async def step(self):
"""
执行 Dify Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
# 执行 Dify 请求并处理结果
async for response in self._execute_dify_request():
yield response
except Exception as e:
logger.error(f"Dify 请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Dify 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _execute_dify_request(self):
"""执行 Dify 请求的核心逻辑"""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
system_prompt = self.req.system_prompt
conversation_id = await sp.get_async(
scope="umo",
scope_id=session_id,
key="dify_conversation_id",
default="",
)
result = ""
# 处理图片上传
files_payload = []
for image_url in image_urls:
# image_url is a base64 string
try:
image_data = base64.b64decode(image_url)
file_response = await self.api_client.file_upload(
file_data=image_data,
user=session_id,
mime_type="image/png",
file_name="image.png",
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
except Exception as e:
logger.warning(f"上传图片失败:{e}")
continue
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
# 处理不同的 API 类型
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk["event"] == "message" or chunk["event"] == "agent_message":
result += chunk["answer"]
if not conversation_id:
await sp.put_async(
scope="umo",
scope_id=session_id,
key="dify_conversation_id",
value=chunk["conversation_id"],
)
conversation_id = chunk["conversation_id"]
# 如果是流式响应,发送增量数据
if self.streaming and chunk["answer"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(chunk["answer"])
),
)
elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
raise Exception(
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
)
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify workflow resp chunk: {chunk}")
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
)
case "node_finished":
logger.debug(
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
)
case "text_chunk":
if self.streaming and chunk["data"]["text"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(
chunk["data"]["text"]
)
),
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
)
logger.debug(f"Dify 工作流结果:{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}"
)
raise Exception(
f"Dify 工作流出现错误:{chunk['data']['error']}"
)
if self.workflow_output_key not in chunk["data"]["outputs"]:
raise Exception(
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
)
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
# 解析结果
chain = await self.parse_dify_result(result)
# 创建最终响应
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回最终结果
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
"""解析 Dify 的响应结果"""
if isinstance(chunk, str):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains = []
if isinstance(output, str):
# 纯文本输出
chains.append(Comp.Plain(output))
elif isinstance(output, list):
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
for item in output:
# handle Array[File]
if (
not isinstance(item, dict)
or item.get("dify_model_identity", "") != "__dify__file__"
):
chains.append(Comp.Plain(str(output)))
break
else:
chains.append(Comp.Plain(str(output)))
# scan file
files = chunk["data"].get("files", [])
for item in files:
comp = await parse_file(item)
chains.append(comp)
return MessageChain(chain=chains)
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..message import AssistantMessageSegment, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
@@ -55,19 +55,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.agent_hooks = agent_hooks
self.run_context = run_context
messages = []
# append existing messages in the run context
for msg in request.contexts:
messages.append(Message.model_validate(msg))
if request.prompt is not None:
m = await request.assemble_context()
messages.append(Message.model_validate(m))
if request.system_prompt:
messages.insert(
0,
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
@@ -104,22 +96,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
elif llm_response.completion_text:
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text),
),
)
elif llm_response.reasoning_content:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain(type="reasoning").message(
llm_response.reasoning_content,
),
),
)
continue
llm_resp_result = llm_response
break # got final response
@@ -147,13 +130,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# record the final assistant message
self.run_context.messages.append(
Message(
role="assistant",
content=llm_resp.completion_text or "",
),
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
@@ -180,16 +156,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
result.type = "tool_call_result"
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
@@ -202,23 +175,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
tool_calls_result=tool_call_result_blocks,
)
# record the assistant message with tool calls
self.run_context.messages.extend(
tool_calls_result.to_openai_messages_model()
)
self.req.append_tool_calls_result(tool_calls_result)
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
step_count = 0
while not self.done() and step_count < max_step:
step_count += 1
async for resp in self.step():
yield resp
async def _handle_function_tools(
self,
req: ProviderRequest,

View File

@@ -4,13 +4,12 @@ from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated
from pydantic import Field, model_validator
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult
@dataclass
@@ -56,14 +55,23 @@ class FunctionTool(ToolSchema, Generic[TContext]):
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
def __dict__(self) -> dict[str, Any]:
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
}
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> str | mcp.types.CallToolResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
@dataclass
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -71,7 +79,8 @@ class ToolSet:
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
tools: list[FunctionTool] = Field(default_factory=list)
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""

View File

@@ -1,19 +1,14 @@
from pydantic import Field
from pydantic.dataclasses import dataclass
from dataclasses import dataclass
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.context import Context
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class AstrAgentContext:
context: Context
"""The star context instance"""
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
event: AstrMessageEvent
"""The message event associated with the agent context."""
extra: dict[str, str] = Field(default_factory=dict)
"""Customized extra data."""
AgentContextWrapper = ContextWrapper[AstrAgentContext]

View File

@@ -1,36 +0,0 @@
from typing import Any
from mcp.types import CallToolResult
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
pass
MAIN_AGENT_HOOKS = MainAgentHooks()

View File

@@ -1,80 +0,0 @@
import traceback
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
stream_to_general: bool = False,
show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use:
await astr_event.send(resp.data["chain"])
continue
if stream_to_general and resp.type == "streaming_delta":
continue
if stream_to_general or not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
chain = resp.data["chain"]
if chain.type == "reasoning" and not show_reasoning:
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return

View File

@@ -1,246 +0,0 @@
import asyncio
import inspect
import traceback
import typing as T
import mcp
from astrbot import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
MessageEventResult,
)
from astrbot.core.provider.register import llm_tools
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input")
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
ctx = run_context.context.context
event = run_context.context.event
umo = event.unified_msg_origin
prov_id = await ctx.get_current_chat_provider_id(umo)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
prompt=input_,
system_prompt=tool.agent.instructions,
tools=toolset,
max_steps=30,
run_hooks=tool.agent.run_hooks,
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
)
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -3,11 +3,15 @@ import json
import logging
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
ASTRBOT_CONFIG_PATH = str(AstrbotPaths.astrbot_root / "cmd_config.json")
logger = logging.getLogger("astrbot")

View File

@@ -1,11 +1,17 @@
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
import os
from importlib.metadata import version
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
VERSION = "4.7.3"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
# 警告,请使用version函数获取版本,此变量兼容保留
VERSION = version("astrbot")
DB_PATH = str(AstrbotPaths.astrbot_root / "data_v4.db")
# 默认配置
DEFAULT_CONFIG = {
@@ -68,11 +74,7 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"agent_runner_type": "local",
"dify_agent_runner_provider_id": "",
"coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting",
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
@@ -90,7 +92,6 @@ DEFAULT_CONFIG = {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_provider_id": "",
"active_reply": {
"enable": False,
"method": "possibility_reply",
@@ -142,20 +143,10 @@ DEFAULT_CONFIG = {
"kb_names": [], # 默认知识库名称列表
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
"kb_agentic_mode": False,
}
"""
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
1. 保存配置时,配置项的类型验证
2. WebUI 展示提供商和平台适配器模版
WebUI 的配置文件在 `CONFIG_METADATA_3` 中。
未来将会逐步淘汰此配置元数据。
"""
# 配置项的中文描述、值类型
CONFIG_METADATA_2 = {
"platform_group": {
"metadata": {
@@ -648,7 +639,7 @@ CONFIG_METADATA_2 = {
},
"words_count_threshold": {
"type": "int",
"hint": "分段回复的字数上限。只有字数小于此值的消息会被分段,超过此值的长消息将直接发送(不分段)。默认为 150",
"hint": "超过这个字数的消息会被分段回复。默认为 150",
},
"regex": {
"type": "string",
@@ -755,7 +746,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
@@ -771,7 +761,6 @@ CONFIG_METADATA_2 = {
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -785,7 +774,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"xai_native_search": False,
"modalities": ["text", "image", "tool_use"],
@@ -817,7 +805,6 @@ CONFIG_METADATA_2 = {
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -832,7 +819,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "llama-3.1-8b",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -849,7 +835,6 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -891,24 +876,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"Groq": {
"id": "groq_default",
"provider": "groq",
"type": "groq_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.groq.com/openai/v1",
"timeout": 120,
"model_config": {
"model": "openai/gpt-oss-20b",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
@@ -922,7 +889,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -939,7 +905,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -956,7 +921,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
},
"小马算力": {
@@ -972,7 +936,6 @@ CONFIG_METADATA_2 = {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_headers": {},
"custom_extra_body": {},
},
"优云智算": {
@@ -987,7 +950,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -1001,7 +963,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -1017,15 +978,13 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "glm-4-flash",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Dify": {
"id": "dify_app_default",
"provider": "dify",
"type": "dify",
"provider_type": "agent_runner",
"provider_type": "chat_completion",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
@@ -1039,20 +998,20 @@ CONFIG_METADATA_2 = {
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "agent_runner",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
# "auto_save_history": True,
"auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
"type": "dashscope",
"provider_type": "agent_runner",
"provider_type": "chat_completion",
"enable": True,
"dashscope_app_type": "agent",
"dashscope_api_key": "",
@@ -1075,7 +1034,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
@@ -1088,7 +1046,6 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60,
"custom_headers": {},
"custom_extra_body": {},
},
"Whisper(API)": {
@@ -1101,7 +1058,7 @@ CONFIG_METADATA_2 = {
"api_base": "",
"model": "whisper-1",
},
"Whisper(Local)": {
"Whisper(本地加载)": {
"hint": "启用前请 pip 安装 openai-whisper 库N卡用户大约下载 2GB主要是 torch 和 cudaCPU 用户大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"provider": "openai",
"type": "openai_whisper_selfhost",
@@ -1110,7 +1067,7 @@ CONFIG_METADATA_2 = {
"id": "whisper_selfhost",
"model": "tiny",
},
"SenseVoice(Local)": {
"SenseVoice(本地加载)": {
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库默认使用CPU大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"type": "sensevoice_stt_selfhost",
"provider": "sensevoice",
@@ -1145,7 +1102,7 @@ CONFIG_METADATA_2 = {
"pitch": "+0Hz",
"timeout": 20,
},
"GSV TTS(Local)": {
"GSV TTS(本地加载)": {
"id": "gsv_tts",
"enable": False,
"provider": "gpt_sovits",
@@ -1322,19 +1279,6 @@ CONFIG_METADATA_2 = {
"timeout": 20,
"launch_model_if_not_running": False,
},
"阿里云百炼重排序": {
"id": "bailian_rerank",
"type": "bailian_rerank",
"provider": "bailian",
"provider_type": "rerank",
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
"rerank_model": "qwen3-rerank",
"timeout": 30,
"return_documents": False,
"instruct": "",
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
@@ -1369,16 +1313,6 @@ CONFIG_METADATA_2 = {
"description": "重排序模型名称",
"type": "string",
},
"return_documents": {
"description": "是否在排序结果中返回文档原文",
"type": "bool",
"hint": "默认值false以减少网络传输开销。",
},
"instruct": {
"description": "自定义排序任务类型说明",
"type": "string",
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
},
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
@@ -1393,12 +1327,6 @@ CONFIG_METADATA_2 = {
"render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
},
"custom_headers": {
"description": "自定义添加请求头",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
},
"custom_extra_body": {
"description": "自定义请求体参数",
"type": "dict",
@@ -1921,6 +1849,7 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用",
"type": "bool",
"hint": "是否启用。",
},
"key": {
"description": "API Key",
@@ -2047,25 +1976,15 @@ CONFIG_METADATA_2 = {
"show_tool_use_status": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
},
"agent_runner_type": {
"type": "string",
},
"dify_agent_runner_provider_id": {
"type": "string",
},
"coze_agent_runner_provider_id": {
"type": "string",
},
"dashscope_agent_runner_provider_id": {
"type": "string",
"streaming_segmented": {
"type": "bool",
},
"max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
},
@@ -2110,9 +2029,6 @@ CONFIG_METADATA_2 = {
"image_caption": {
"type": "bool",
},
"image_caption_provider_id": {
"type": "string",
},
"image_caption_prompt": {
"type": "string",
},
@@ -2196,93 +2112,39 @@ CONFIG_METADATA_2 = {
"kb_names": {"type": "list", "items": {"type": "string"}},
"kb_fusion_top_k": {"type": "int", "default": 20},
"kb_final_top_k": {"type": "int", "default": 5},
"kb_agentic_mode": {"type": "bool"},
},
},
}
"""
v4.7.0 之后name, description, hint 等字段已经实现 i18n 国际化。国际化资源文件位于:
- dashboard/src/i18n/locales/en-US/features/config-metadata.json
- dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
如果在此文件中添加了新的配置字段,请务必同步更新上述两个国际化资源文件。
"""
CONFIG_METADATA_3 = {
"ai_group": {
"name": "AI 配置",
"metadata": {
"agent_runner": {
"description": "Agent 执行方式",
"hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。",
"ai": {
"description": "模型",
"type": "object",
"items": {
"provider_settings.enable": {
"description": "启用",
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "AI 对话总开关",
},
"provider_settings.agent_runner_type": {
"description": "执行器",
"type": "string",
"options": ["local", "dify", "coze", "dashscope"],
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"],
"condition": {
"provider_settings.enable": True,
},
},
"provider_settings.coze_agent_runner_provider_id": {
"description": "Coze Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:coze",
"condition": {
"provider_settings.agent_runner_type": "coze",
"provider_settings.enable": True,
},
},
"provider_settings.dify_agent_runner_provider_id": {
"description": "Dify Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:dify",
"condition": {
"provider_settings.agent_runner_type": "dify",
"provider_settings.enable": True,
},
},
"provider_settings.dashscope_agent_runner_provider_id": {
"description": "阿里云百炼应用 Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:dashscope",
"condition": {
"provider_settings.agent_runner_type": "dashscope",
"provider_settings.enable": True,
},
},
},
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
"description": "默认聊天模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
"hint": "留空时使用第一个模型",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "留空代表不使用可用于非多模态模型",
"hint": "留空代表不使用可用于不支持视觉模态的聊天模型",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
"hint": "STT 总开关",
"hint": "STT 总开关",
},
"provider_stt_settings.provider_id": {
"description": "默认语音转文本模型",
@@ -2296,11 +2158,12 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
"hint": "TTS 总开关",
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
},
"provider_tts_settings.provider_id": {
"description": "默认文本转语音模型",
"type": "string",
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
@@ -2311,9 +2174,6 @@ CONFIG_METADATA_3 = {
"type": "text",
},
},
"condition": {
"provider_settings.enable": True,
},
},
"persona": {
"description": "人格",
@@ -2325,10 +2185,6 @@ CONFIG_METADATA_3 = {
"_special": "select_persona",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"knowledgebase": {
"description": "知识库",
@@ -2351,15 +2207,6 @@ CONFIG_METADATA_3 = {
"type": "int",
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
},
"kb_agentic_mode": {
"description": "Agentic 知识库检索",
"type": "bool",
"hint": "启用后,知识库检索将作为 LLM Tool由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"websearch": {
@@ -2397,10 +2244,6 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"others": {
"description": "其他配置",
@@ -2409,83 +2252,54 @@ CONFIG_METADATA_3 = {
"provider_settings.display_reasoning_text": {
"description": "显示思考内容",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.identifier": {
"description": "用户识别",
"type": "bool",
"hint": "启用后,会在提示词前包含用户 ID 信息。",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
"hint": "启用后,会在系统提示词中附带当前时间信息。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.show_tool_use_status": {
"description": "输出函数调用状态",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.streaming_response": {
"description": "流式输出",
"description": "流式回复",
"type": "bool",
},
"provider_settings.unsupported_streaming_strategy": {
"description": "不支持流式回复的平台",
"type": "string",
"options": ["realtime_segmenting", "turn_off"],
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
"labels": ["实时分段回复", "关闭流式回复"],
"condition": {
"provider_settings.streaming_response": True,
},
"provider_settings.streaming_segmented": {
"description": "不支持流式回复的平台采取分段输出",
"type": "bool",
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条-1 为不限制",
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat则需要 /chat 才会触发 LLM 请求",
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
},
"provider_settings.prompt_prefix": {
"description": "用户提示词",
@@ -2497,9 +2311,6 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
},
"condition": {
"provider_settings.enable": True,
},
},
},
},
@@ -2789,16 +2600,7 @@ CONFIG_METADATA_3 = {
"provider_ltm_settings.image_caption": {
"description": "自动理解图片",
"type": "bool",
"hint": "需要设置群聊图片转述模型。",
},
"provider_ltm_settings.image_caption_provider_id": {
"description": "群聊图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
"condition": {
"provider_ltm_settings.image_caption": True,
},
"hint": "需要设置默认图片转述模型。",
},
"provider_ltm_settings.active_reply.enable": {
"description": "主动回复",

View File

@@ -1,110 +0,0 @@
"""
配置元数据国际化工具
提供配置元数据的国际化键转换功能
"""
from typing import Any
class ConfigMetadataI18n:
"""配置元数据国际化转换器"""
@staticmethod
def _get_i18n_key(group: str, section: str, field: str, attr: str) -> str:
"""
生成国际化键
Args:
group: 配置组,如 'ai_group', 'platform_group'
section: 配置节,如 'agent_runner', 'general'
field: 字段名,如 'enable', 'default_provider'
attr: 属性类型,如 'description', 'hint', 'labels'
Returns:
国际化键,格式如: 'ai_group.agent_runner.enable.description'
"""
if field:
return f"{group}.{section}.{field}.{attr}"
else:
return f"{group}.{section}.{attr}"
@staticmethod
def convert_to_i18n_keys(metadata: dict[str, Any]) -> dict[str, Any]:
"""
将配置元数据转换为使用国际化键
Args:
metadata: 原始配置元数据字典
Returns:
使用国际化键的配置元数据字典
"""
result = {}
for group_key, group_data in metadata.items():
group_result = {
"name": f"{group_key}.name",
"metadata": {},
}
for section_key, section_data in group_data.get("metadata", {}).items():
section_result = {
"description": f"{group_key}.{section_key}.description",
"type": section_data.get("type"),
}
# 复制其他属性
for key in ["items", "condition", "_special", "invisible"]:
if key in section_data:
section_result[key] = section_data[key]
# 处理 hint
if "hint" in section_data:
section_result["hint"] = f"{group_key}.{section_key}.hint"
# 处理 items 中的字段
if "items" in section_data and isinstance(section_data["items"], dict):
items_result = {}
for field_key, field_data in section_data["items"].items():
# 处理嵌套的点号字段名(如 provider_settings.enable
field_name = field_key
field_result = {}
# 复制基本属性
for attr in [
"type",
"condition",
"_special",
"invisible",
"options",
]:
if attr in field_data:
field_result[attr] = field_data[attr]
# 转换文本属性为国际化键
if "description" in field_data:
field_result["description"] = (
f"{group_key}.{section_key}.{field_name}.description"
)
if "hint" in field_data:
field_result["hint"] = (
f"{group_key}.{section_key}.{field_name}.hint"
)
if "labels" in field_data:
field_result["labels"] = (
f"{group_key}.{section_key}.{field_name}.labels"
)
items_result[field_key] = field_result
section_result["items"] = items_result
group_result["metadata"][section_key] = section_result
result[group_key] = group_result
return result

View File

@@ -16,12 +16,12 @@ import time
import traceback
from asyncio import Queue
from astrbot.api import logger, sp
from astrbot.core import LogBroker
from astrbot.core import LogBroker, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
@@ -33,7 +33,6 @@ from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.migra_helper import migra
from . import astrbot_config, html_renderer
from .event_bus import EventBus
@@ -97,16 +96,11 @@ class AstrBotCoreLifecycle:
sp=sp,
)
# apply migration
# 4.5 to 4.6 migration for umop_config_router
try:
await migra(
self.db,
self.astrbot_config_mgr,
self.umop_config_router,
self.astrbot_config_mgr,
)
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
except Exception as e:
logger.error(f"AstrBot migration failed: {e!s}")
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列

View File

@@ -13,7 +13,6 @@ from astrbot.core.db.po import (
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
Stats,
@@ -184,7 +183,7 @@ class BaseDatabase(abc.ABC):
user_id: str,
offset_sec: int = 86400,
) -> None:
"""Delete platform message history records newer than the specified offset."""
"""Delete platform message history records older than the specified offset."""
...
@abc.abstractmethod
@@ -314,51 +313,3 @@ class BaseDatabase(abc.ABC):
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...
# ====
# Platform Session Management
# ====
@abc.abstractmethod
async def create_platform_session(
self,
creator: str,
platform_id: str = "webchat",
session_id: str | None = None,
display_name: str | None = None,
is_group: int = 0,
) -> PlatformSession:
"""Create a new Platform session."""
...
@abc.abstractmethod
async def get_platform_session_by_id(
self, session_id: str
) -> PlatformSession | None:
"""Get a Platform session by its ID."""
...
@abc.abstractmethod
async def get_platform_sessions_by_creator(
self,
creator: str,
platform_id: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[PlatformSession]:
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
...
@abc.abstractmethod
async def update_platform_session(
self,
session_id: str,
display_name: str | None = None,
) -> None:
"""Update a Platform session's updated_at timestamp and optionally display_name."""
...
@abc.abstractmethod
async def delete_platform_session(self, session_id: str) -> None:
"""Delete a Platform session by its ID."""
...

View File

@@ -1,131 +0,0 @@
"""Migration script for WebChat sessions.
This migration creates PlatformSession from existing platform_message_history records.
Changes:
- Creates platform_sessions table
- Adds platform_id field (default: 'webchat')
- Adds display_name field
- Session_id format: {platform_id}_{uuid}
"""
from sqlalchemy import func, select
from sqlmodel import col
from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
async def migrate_webchat_session(db_helper: BaseDatabase):
"""Create PlatformSession records from platform_message_history.
This migration extracts all unique user_ids from platform_message_history
where platform_id='webchat' and creates corresponding PlatformSession records.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_webchat_session_1"
)
if migration_done:
return
logger.info("开始执行数据库迁移WebChat 会话迁移)...")
try:
async with db_helper.get_db() as session:
# 从 platform_message_history 创建 PlatformSession
query = (
select(
col(PlatformMessageHistory.user_id),
col(PlatformMessageHistory.sender_name),
func.min(PlatformMessageHistory.created_at).label("earliest"),
func.max(PlatformMessageHistory.updated_at).label("latest"),
)
.where(col(PlatformMessageHistory.platform_id) == "webchat")
.where(col(PlatformMessageHistory.sender_id) != "bot")
.group_by(col(PlatformMessageHistory.user_id))
)
result = await session.execute(query)
webchat_users = result.all()
if not webchat_users:
logger.info("没有找到需要迁移的 WebChat 数据")
await sp.put_async(
"global", "global", "migration_done_webchat_session_1", True
)
return
logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")
# 检查已存在的会话
existing_query = select(col(PlatformSession.session_id))
existing_result = await session.execute(existing_query)
existing_session_ids = {row[0] for row in existing_result.fetchall()}
# 查询 Conversations 表中的 title用于设置 display_name
# 对于每个 user_id对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id}
user_ids_to_query = [
f"webchat:FriendMessage:webchat!astrbot!{user_id}"
for user_id, _, _, _ in webchat_users
]
conv_query = select(
col(ConversationV2.user_id), col(ConversationV2.title)
).where(col(ConversationV2.user_id).in_(user_ids_to_query))
conv_result = await session.execute(conv_query)
# 创建 user_id -> title 的映射字典
title_map = {
user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title
for user_id, title in conv_result.fetchall()
}
# 批量创建 PlatformSession 记录
sessions_to_add = []
skipped_count = 0
for user_id, sender_name, created_at, updated_at in webchat_users:
# user_id 就是 webchat_conv_id (session_id)
session_id = user_id
# sender_name 通常是 username但可能为 None
creator = sender_name if sender_name else "guest"
# 检查是否已经存在该会话
if session_id in existing_session_ids:
logger.debug(f"会话 {session_id} 已存在,跳过")
skipped_count += 1
continue
# 从 Conversations 表中获取 display_name
display_name = title_map.get(user_id)
# 创建新的 PlatformSession保留原有的时间戳
new_session = PlatformSession(
session_id=session_id,
platform_id="webchat",
creator=creator,
is_group=0,
created_at=created_at,
updated_at=updated_at,
display_name=display_name,
)
sessions_to_add.append(new_session)
# 批量插入
if sessions_to_add:
session.add_all(sessions_to_add)
await session.commit()
logger.info(
f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}",
)
else:
logger.info("没有新会话需要迁移")
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_webchat_session_1", True)
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
raise

View File

@@ -3,7 +3,13 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TypedDict
from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
from sqlmodel import (
JSON,
Field,
SQLModel,
Text,
UniqueConstraint,
)
class PlatformStat(SQLModel, table=True):
@@ -12,7 +18,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats" # type: ignore
__tablename__ = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
@@ -31,7 +37,7 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore
__tablename__ = "conversations"
inner_conversation_id: int = Field(
primary_key=True,
@@ -68,7 +74,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas" # type: ignore
__tablename__ = "personas"
id: int | None = Field(
primary_key=True,
@@ -98,7 +104,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore
__tablename__ = "preferences"
id: int | None = Field(
default=None,
@@ -134,7 +140,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages.
"""
__tablename__ = "platform_message_history" # type: ignore
__tablename__ = "platform_message_history"
id: int | None = Field(
primary_key=True,
@@ -155,55 +161,13 @@ class PlatformMessageHistory(SQLModel, table=True):
)
class PlatformSession(SQLModel, table=True):
"""Platform session table for managing user sessions across different platforms.
A session represents a chat window for a specific user on a specific platform.
Each session can have multiple conversations (对话) associated with it.
"""
__tablename__ = "platform_sessions" # type: ignore
inner_id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
session_id: str = Field(
max_length=100,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(default="webchat", nullable=False)
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
creator: str = Field(nullable=False)
"""Username of the session creator"""
display_name: str | None = Field(default=None, max_length=255)
"""Display name for the session"""
is_group: int = Field(default=0, nullable=False)
"""0 for private chat, 1 for group chat (not implemented yet)"""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"session_id",
name="uix_platform_session_id",
),
)
class Attachment(SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments" # type: ignore
__tablename__ = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True,

View File

@@ -1,7 +1,7 @@
import asyncio
import threading
import typing as T
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -12,7 +12,6 @@ from astrbot.core.db.po import (
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
SQLModel,
@@ -413,7 +412,7 @@ class SQLiteDatabase(BaseDatabase):
user_id,
offset_sec=86400,
):
"""Delete platform message history records newer than the specified offset."""
"""Delete platform message history records older than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -423,7 +422,7 @@ class SQLiteDatabase(BaseDatabase):
delete(PlatformMessageHistory).where(
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) >= cutoff_time,
col(PlatformMessageHistory.created_at) < cutoff_time,
),
)
@@ -710,101 +709,3 @@ class SQLiteDatabase(BaseDatabase):
t.start()
t.join()
return result
# ====
# Platform Session Management
# ====
async def create_platform_session(
self,
creator: str,
platform_id: str = "webchat",
session_id: str | None = None,
display_name: str | None = None,
is_group: int = 0,
) -> PlatformSession:
"""Create a new Platform session."""
kwargs = {}
if session_id:
kwargs["session_id"] = session_id
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_session = PlatformSession(
creator=creator,
platform_id=platform_id,
display_name=display_name,
is_group=is_group,
**kwargs,
)
session.add(new_session)
await session.flush()
await session.refresh(new_session)
return new_session
async def get_platform_session_by_id(
self, session_id: str
) -> PlatformSession | None:
"""Get a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformSession).where(
PlatformSession.session_id == session_id,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_platform_sessions_by_creator(
self,
creator: str,
platform_id: str | None = None,
page: int = 1,
page_size: int = 20,
) -> list[PlatformSession]:
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = select(PlatformSession).where(PlatformSession.creator == creator)
if platform_id:
query = query.where(PlatformSession.platform_id == platform_id)
query = (
query.order_by(desc(PlatformSession.updated_at))
.offset(offset)
.limit(page_size)
)
result = await session.execute(query)
return list(result.scalars().all())
async def update_platform_session(
self,
session_id: str,
display_name: str | None = None,
) -> None:
"""Update a Platform session's updated_at timestamp and optionally display_name."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
if display_name is not None:
values["display_name"] = display_name
await session.execute(
update(PlatformSession)
.where(col(PlatformSession.session_id) == session_id)
.values(**values),
)
async def delete_platform_session(self, session_id: str) -> None:
"""Delete a Platform session by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(PlatformSession).where(
col(PlatformSession.session_id) == session_id,
),
)

View File

@@ -1,9 +0,0 @@
from __future__ import annotations
class AstrBotError(Exception):
"""Base exception for all AstrBot errors."""
class ProviderNotFoundError(AstrBotError):
"""Raised when a specified provider is not found."""

View File

@@ -1,7 +1,4 @@
import asyncio
import json
import re
import time
import uuid
from pathlib import Path
@@ -11,98 +8,12 @@ from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
EmbeddingProvider,
RerankProvider,
)
from astrbot.core.provider.provider import (
Provider as LLMProvider,
)
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from .chunking.base import BaseChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .models import KBDocument, KBMedia, KnowledgeBase
from .parsers.url_parser import extract_text_from_url
from .parsers.util import select_parser
from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
class RateLimiter:
"""一个简单的速率限制器"""
def __init__(self, max_rpm: int):
self.max_per_minute = max_rpm
self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
self.last_call_time = 0
async def __aenter__(self):
if self.interval == 0:
return
now = time.monotonic()
elapsed = now - self.last_call_time
if elapsed < self.interval:
await asyncio.sleep(self.interval - elapsed)
self.last_call_time = time.monotonic()
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def _repair_and_translate_chunk_with_retry(
chunk: str,
repair_llm_service: LLMProvider,
rate_limiter: RateLimiter,
max_retries: int = 2,
) -> list[str]:
"""
Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting.
"""
# 为了防止 LLM 上下文污染,在 user_prompt 中也加入明确的指令
user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided.
Text chunk to process:
---
{chunk}
---
"""
for attempt in range(max_retries + 1):
try:
async with rate_limiter:
response = await repair_llm_service.text_chat(
prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT
)
llm_output = response.completion_text
if "<discard_chunk />" in llm_output:
return [] # Signal to discard this chunk
# More robust regex to handle potential LLM formatting errors (spaces, newlines in tags)
matches = re.findall(
r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
llm_output,
re.DOTALL,
)
if matches:
# Further cleaning to ensure no empty strings are returned
return [m.strip() for m in matches if m.strip()]
else:
# If no valid tags and not explicitly discarded, discard it to be safe.
return []
except Exception as e:
logger.warning(
f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}"
)
logger.error(
f" - Failed to process chunk after {max_retries + 1} attempts. Using original text."
)
return [chunk]
class KBHelper:
@@ -189,7 +100,7 @@ class KBHelper:
async def upload_document(
self,
file_name: str,
file_content: bytes | None,
file_content: bytes,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
@@ -197,7 +108,6 @@ class KBHelper:
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
pre_chunked_text: list[str] | None = None,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
@@ -220,63 +130,46 @@ class KBHelper:
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
file_size = 0
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
# await f.write(file_content)
try:
chunks_text = []
# 阶段1: 解析文档
if progress_callback:
await progress_callback("parsing", 0, 100)
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if progress_callback:
await progress_callback("parsing", 100, 100)
# 保存媒体文件
saved_media = []
if pre_chunked_text is not None:
# 如果提供了预分块文本,直接使用
chunks_text = pre_chunked_text
file_size = sum(len(chunk) for chunk in chunks_text)
logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。")
else:
# 否则,执行标准的文件解析和分块流程
if file_content is None:
raise ValueError(
"当未提供 pre_chunked_text 时file_content 不能为空。"
)
file_size = len(file_content)
# 阶段1: 解析文档
if progress_callback:
await progress_callback("parsing", 0, 100)
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if progress_callback:
await progress_callback("parsing", 100, 100)
# 保存媒体文件
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 阶段2: 分块
if progress_callback:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 阶段2: 分块
if progress_callback:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
@@ -312,7 +205,7 @@ class KBHelper:
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=file_size,
file_size=len(file_content),
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
@@ -466,177 +359,3 @@ class KBHelper:
)
return media
async def upload_from_url(
self,
url: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
enable_cleaning: bool = False,
cleaning_provider_id: str | None = None,
) -> KBDocument:
"""从 URL 上传并处理文档(带原子性保证和失败清理)
Args:
url: 要提取内容的网页 URL
chunk_size: 文本块大小
chunk_overlap: 文本块重叠大小
batch_size: 批处理大小
tasks_limit: 并发任务限制
max_retries: 最大重试次数
progress_callback: 进度回调函数,接收参数 (stage, current, total)
- stage: 当前阶段 ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding')
- current: 当前进度
- total: 总数
Returns:
KBDocument: 上传的文档对象
Raises:
ValueError: 如果 URL 为空或无法提取内容
IOError: 如果网络请求失败
"""
# 获取 Tavily API 密钥
config = self.prov_mgr.acm.default_conf
tavily_keys = config.get("provider_settings", {}).get(
"websearch_tavily_key", []
)
if not tavily_keys:
raise ValueError(
"Error: Tavily API key is not configured in provider_settings."
)
# 阶段1: 从 URL 提取内容
if progress_callback:
await progress_callback("extracting", 0, 100)
try:
text_content = await extract_text_from_url(url, tavily_keys)
except Exception as e:
logger.error(f"Failed to extract content from URL {url}: {e}")
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
if not text_content:
raise ValueError(f"No content extracted from URL: {url}")
if progress_callback:
await progress_callback("extracting", 100, 100)
# 阶段2: (可选)清洗内容并分块
final_chunks = await self._clean_and_rechunk_content(
content=text_content,
url=url,
progress_callback=progress_callback,
enable_cleaning=enable_cleaning,
cleaning_provider_id=cleaning_provider_id,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
if enable_cleaning and not final_chunks:
raise ValueError(
"内容清洗后未提取到有效文本。请尝试关闭内容清洗功能或更换更高性能的LLM模型后重试。"
)
# 创建一个虚拟文件名
file_name = url.split("/")[-1] or f"document_from_{url}"
if not Path(file_name).suffix:
file_name += ".url"
# 复用现有的 upload_document 方法,但传入预分块文本
return await self.upload_document(
file_name=file_name,
file_content=None,
file_type="url", # 使用 'url' 作为特殊文件类型
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=final_chunks,
)
async def _clean_and_rechunk_content(
self,
content: str,
url: str,
progress_callback=None,
enable_cleaning: bool = False,
cleaning_provider_id: str | None = None,
repair_max_rpm: int = 60,
chunk_size: int = 512,
chunk_overlap: int = 50,
) -> list[str]:
"""
对从 URL 获取的内容进行清洗、修复、翻译和重新分块。
"""
if not enable_cleaning:
# 如果不启用清洗,则使用从前端传递的参数进行分块
logger.info(
f"内容清洗未启用,使用指定参数进行分块: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
)
return await self.chunker.chunk(
content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
if not cleaning_provider_id:
logger.warning(
"启用了内容清洗,但未提供 cleaning_provider_id跳过清洗并使用默认分块。"
)
return await self.chunker.chunk(content)
if progress_callback:
await progress_callback("cleaning", 0, 100)
try:
# 获取指定的 LLM Provider
llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id)
if not llm_provider or not isinstance(llm_provider, LLMProvider):
raise ValueError(
f"无法找到 ID 为 {cleaning_provider_id} 的 LLM Provider 或类型不正确"
)
# 初步分块
# 优化分隔符,优先按段落分割,以获得更高质量的文本块
text_splitter = RecursiveCharacterChunker(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=["\n\n", "\n", " "], # 优先使用段落分隔符
)
initial_chunks = await text_splitter.chunk(content)
logger.info(f"初步分块完成,生成 {len(initial_chunks)} 个块用于修复。")
# 并发处理所有块
rate_limiter = RateLimiter(repair_max_rpm)
tasks = [
_repair_and_translate_chunk_with_retry(
chunk, llm_provider, rate_limiter
)
for chunk in initial_chunks
]
repaired_results = await asyncio.gather(*tasks, return_exceptions=True)
final_chunks = []
for i, result in enumerate(repaired_results):
if isinstance(result, Exception):
logger.warning(f"{i} 处理异常: {str(result)}. 回退到原始块。")
final_chunks.append(initial_chunks[i])
elif isinstance(result, list):
final_chunks.extend(result)
logger.info(
f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。"
)
if progress_callback:
await progress_callback("cleaning", 100, 100)
return final_chunks
except Exception as e:
logger.error(f"使用 Provider '{cleaning_provider_id}' 清洗内容失败: {e}")
# 清洗失败,返回默认分块结果,保证流程不中断
return await self.chunker.chunk(content)

View File

@@ -8,7 +8,7 @@ from astrbot.core.provider.manager import ProviderManager
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .kb_helper import KBHelper
from .models import KBDocument, KnowledgeBase
from .models import KnowledgeBase
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.rank_fusion import RankFusion
from .retrieval.sparse_retriever import SparseRetriever
@@ -284,47 +284,3 @@ class KnowledgeBaseManager:
await self.kb_db.close()
except Exception as e:
logger.error(f"关闭知识库元数据数据库失败: {e}")
async def upload_from_url(
self,
kb_id: str,
url: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""从 URL 上传文档到指定的知识库
Args:
kb_id: 知识库 ID
url: 要提取内容的网页 URL
chunk_size: 文本块大小
chunk_overlap: 文本块重叠大小
batch_size: 批处理大小
tasks_limit: 并发任务限制
max_retries: 最大重试次数
progress_callback: 进度回调函数
Returns:
KBDocument: 上传的文档对象
Raises:
ValueError: 如果知识库不存在或 URL 为空
IOError: 如果网络请求失败
"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
raise ValueError(f"Knowledge base with id {kb_id} not found.")
return await kb_helper.upload_from_url(
url=url,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)

View File

@@ -1,103 +0,0 @@
import asyncio
import aiohttp
class URLExtractor:
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
def __init__(self, tavily_keys: list[str]):
"""
初始化 URL 提取器
Args:
tavily_keys: Tavily API 密钥列表
"""
if not tavily_keys:
raise ValueError("Error: Tavily API keys are not configured.")
self.tavily_keys = tavily_keys
self.tavily_key_index = 0
self.tavily_key_lock = asyncio.Lock()
async def _get_tavily_key(self) -> str:
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
async with self.tavily_key_lock:
key = self.tavily_keys[self.tavily_key_index]
self.tavily_key_index = (self.tavily_key_index + 1) % len(self.tavily_keys)
return key
async def extract_text_from_url(self, url: str) -> str:
"""
使用 Tavily API 从 URL 提取主要文本内容。
这是 web_searcher 插件中 tavily_extract_web_page 方法的简化版本,
专门为知识库模块设计,不依赖 AstrMessageEvent。
Args:
url: 要提取内容的网页 URL
Returns:
提取的文本内容
Raises:
ValueError: 如果 URL 为空或 API 密钥未配置
IOError: 如果请求失败或返回错误
"""
if not url:
raise ValueError("Error: url must be a non-empty string.")
tavily_key = await self._get_tavily_key()
api_url = "https://api.tavily.com/extract"
headers = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
payload = {
"urls": [url],
"extract_depth": "basic", # 使用基础提取深度
}
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
api_url,
json=payload,
headers=headers,
timeout=30.0, # 增加超时时间,因为内容提取可能需要更长时间
) as response:
if response.status != 200:
reason = await response.text()
raise OSError(
f"Tavily web extraction failed: {reason}, status: {response.status}"
)
data = await response.json()
results = data.get("results", [])
if not results:
raise ValueError(f"No content extracted from URL: {url}")
# 返回第一个结果的内容
return results[0].get("raw_content", "")
except aiohttp.ClientError as e:
raise OSError(f"Failed to fetch URL {url}: {e}") from e
except Exception as e:
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
# 为了向后兼容,提供一个简单的函数接口
async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
"""
简单的函数接口,用于从 URL 提取文本内容
Args:
url: 要提取内容的网页 URL
tavily_keys: Tavily API 密钥列表
Returns:
提取的文本内容
"""
extractor = URLExtractor(tavily_keys)
return await extractor.extract_text_from_url(url)

View File

@@ -1,65 +0,0 @@
TEXT_REPAIR_SYSTEM_PROMPT = """You are a meticulous digital archivist. Your mission is to reconstruct a clean, readable article from raw, noisy text chunks.
**Core Task:**
1. **Analyze:** Examine the text chunk to separate "signal" (substantive information) from "noise" (UI elements, ads, navigation, footers).
2. **Process:** Clean and repair the signal. **Do not translate it.** Keep the original language.
**Crucial Rules:**
- **NEVER discard a chunk if it contains ANY valuable information.** Your primary duty is to salvage content.
- **If a chunk contains multiple distinct topics, split them.** Enclose each topic in its own `<repaired_text>` tag.
- Your output MUST be ONLY `<repaired_text>...</repaired_text>` tags or a single `<discard_chunk />` tag.
---
**Example 1: Chunk with Noise and Signal**
*Input Chunk:*
"Home | About | Products | **The Llama is a domesticated South American camelid.** | © 2025 ACME Corp."
*Your Thought Process:*
1. "Home | About | Products..." and "© 2025 ACME Corp." are noise.
2. "The Llama is a domesticated..." is the signal.
3. I must extract the signal and wrap it.
*Your Output:*
<repaired_text>
The Llama is a domesticated South American camelid.
</repaired_text>
---
**Example 2: Chunk with ONLY Noise**
*Input Chunk:*
"Next Page > | Subscribe to our newsletter | Follow us on X"
*Your Thought Process:*
1. This entire chunk is noise. There is no signal.
2. I must discard this.
*Your Output:*
<discard_chunk />
---
**Example 3: Chunk with Multiple Topics (Requires Splitting)**
*Input Chunk:*
"## Chapter 1: The Sun
The Sun is the star at the center of the Solar System.
## Chapter 2: The Moon
The Moon is Earth's only natural satellite."
*Your Thought Process:*
1. This chunk contains two distinct topics.
2. I must process them separately to maintain semantic integrity.
3. I will create two `<repaired_text>` blocks.
*Your Output:*
<repaired_text>
## Chapter 1: The Sun
The Sun is the star at the center of the Solar System.
</repaired_text>
<repaired_text>
## Chapter 2: The Moon
The Moon is Earth's only natural satellite.
</repaired_text>
"""

View File

@@ -21,20 +21,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import asyncio
import base64
import json
import os
import uuid
from enum import Enum
from astrbot_api.abc import IAstrbotPaths
from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
class ComponentType(str, Enum):
# Basic Segment Types
Plain = "Plain" # plain text message
@@ -153,8 +153,7 @@ class Record(BaseMessageComponent):
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
file_path = str(AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.jpg")
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
@@ -242,8 +241,9 @@ class Video(BaseMessageComponent):
if url and url.startswith("file:///"):
return url[8:]
if url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
video_file_path = str(
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4().hex}"
)
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
@@ -442,8 +442,9 @@ class Image(BaseMessageComponent):
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
image_file_path = str(
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.jpg"
)
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
@@ -527,7 +528,7 @@ class Reply(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
type = ComponentType.Poke
id: int | None = 0
qq: int | None = 0
@@ -654,33 +655,19 @@ class File(BaseMessageComponent):
@property
def file(self) -> str:
"""获取文件路径如果文件不存在但有URL则同步下载文件
"""获取本地文件路径(仅返回已存在的文件
⚠️ 警告:此属性不会自动下载文件!
- 如果文件已存在,返回绝对路径
- 如果只有 URL 没有本地文件,返回空字符串
- 需要下载文件请使用 `await get_file()` 方法
Returns:
str: 文件路径
str: 文件的绝对路径,如果文件不存在则返回空字符串
"""
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@file.setter
@@ -714,15 +701,18 @@ class File(BaseMessageComponent):
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
if self.file_:
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
if not self.url:
raise ValueError("No URL provided for download")
download_dir = str(AstrbotPaths.astrbot_root / "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
file_path = str(AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4().hex}")
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_event_hook, call_handler
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
@dataclass
@@ -15,3 +15,4 @@ class PipelineContext:
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
call_local_llm_tool = call_local_llm_tool

View File

@@ -3,6 +3,8 @@ import traceback
import typing as T
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
@@ -105,3 +107,66 @@ async def call_event_hook(
return True
return event.is_stopped()
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
method_name: str,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
elif method_name == "call":
ready_to_call = handler(context, *args, **kwargs)
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -1,48 +0,0 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from ...context import PipelineContext
from ..stage import Stage
from .agent_sub_stages.internal import InternalAgentSubStage
from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
class AgentRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
for bwp in self.bot_wake_prefixs:
if self.prov_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
)
self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
if agent_runner_type == "local":
self.agent_sub_stage = InternalAgentSubStage()
else:
self.agent_sub_stage = ThirdPartyAgentSubStage()
await self.agent_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug(
"This pipeline does not enable AI capability, skip processing."
)
return
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
return
async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
yield resp

View File

@@ -1,464 +0,0 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
from .....astr_agent_context import AgentContextWrapper
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from .....astr_agent_run_util import AgentRunner, run_agent
from .....astr_agent_tool_exec import FunctionToolExecutor
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
class InternalAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_reasoning = settings.get("display_reasoning_text", False)
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
# 获取对话上下文
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def _apply_kb(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""Apply knowledge base context to the provider request"""
if not self.kb_agentic_mode:
if req.prompt is None:
return
try:
kb_result = await retrieve_knowledge_base(
query=req.prompt,
umo=event.unified_msg_origin,
context=self.ctx.plugin_manager.context,
)
if not kb_result:
return
if req.system_prompt is not None:
req.system_prompt += (
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
)
except Exception as e:
logger.error(f"Error occurred while retrieving knowledge base: {e}")
else:
if req.func_tool is None:
req.func_tool = ToolSet()
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
if len(contexts) // 2 <= self.max_context_length:
return contexts
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
return truncated_contexts
def _modalities_fix(
self,
provider: Provider,
req: ProviderRequest,
):
"""检查提供商的模态能力,清理请求中的不支持内容"""
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
)
req.func_tool = None
def _plugin_tool_fix(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""根据事件中的插件设置,过滤请求中的工具列表"""
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
async def _handle_webchat(
self,
event: AstrMessageEvent,
req: ProviderRequest,
prov: Provider,
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin,
req.conversation.cid,
)
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
latest_pair = messages[-2:]
if not latest_pair:
return
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
),
)
if llm_resp and llm_resp.completion_text:
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
unified_msg_origin=event.unified_msg_origin,
title=title,
conversation_id=req.conversation.cid,
)
async def _save_to_history(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
):
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
return
if not llm_response.completion_text and not req.tool_calls_result:
logger.debug("LLM 响应为空,不保存记录。")
return
if req.contexts is None:
req.contexts = []
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=messages,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
logger.debug("ready to request llm provider")
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
req.prompt = event.message_str[len(provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
# apply knowledge base feature
await self._apply_kb(event, req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# check provider modalities, if provider does not support image/tool_use, clear them in request.
self._modalities_fix(provider, req)
# filter tools, only keep tools from this pipeline's selected plugins
self._plugin_tool_fix(event, req)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
show_reasoning=self.show_reasoning,
),
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain()
.message(final_llm_resp.completion_text)
.chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
stream_to_general,
show_reasoning=self.show_reasoning,
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)

View File

@@ -1,205 +0,0 @@
import asyncio
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from astrbot.core import astrbot_config, logger
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
)
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
"coze": "coze_agent_runner_provider_id",
"dashscope": "dashscope_agent_runner_provider_id",
}
async def run_third_party_agent(
runner: "BaseAgentRunner",
stream_to_general: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""
运行第三方 agent runner 并转换响应格式
类似于 run_agent 函数,但专门处理第三方 agent runner
"""
try:
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
if resp.type == "streaming_delta":
if stream_to_general:
continue
yield resp.data["chain"]
elif resp.type == "llm_result":
if stream_to_general:
yield resp.data["chain"]
except Exception as e:
logger.error(f"Third party agent runner error: {e}")
err_msg = (
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
)
yield MessageChain().message(err_msg)
class ThirdPartyAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.conf = ctx.astrbot_config
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
self.prov_id = self.conf["provider_settings"].get(
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
"",
)
settings = ctx.astrbot_config["provider_settings"]
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
self.prov_cfg: dict = next(
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
{},
)
if not self.prov_id:
logger.error("没有填写 Agent Runner 提供商 ID请前往配置页面配置。")
return
if not self.prov_cfg:
logger.error(
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
)
return
# make provider request
req = ProviderRequest()
req.session_id = event.unified_msg_origin
req.prompt = event.message_str[len(provider_wake_prefix) :]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_base64()
req.image_urls.append(image_path)
if not req.prompt and not req.image_urls:
return
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if self.runner_type == "dify":
runner = DifyAgentRunner[AstrAgentContext]()
elif self.runner_type == "coze":
runner = CozeAgentRunner[AstrAgentContext]()
elif self.runner_type == "dashscope":
runner = DashscopeAgentRunner[AstrAgentContext]()
else:
raise ValueError(
f"Unsupported third party agent runner type: {self.runner_type}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
await runner.reset(
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=60,
),
agent_hooks=MAIN_AGENT_HOOKS,
provider_config=self.prov_cfg,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_third_party_agent(
runner,
stream_to_general=False,
),
),
)
yield
if runner.done():
final_resp = runner.get_final_llm_resp()
if final_resp and final_resp.result_chain:
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
# 非流式响应或转换为普通响应
async for _ in run_third_party_agent(
runner,
stream_to_general=stream_to_general,
):
yield
final_resp = runner.get_final_llm_resp()
if not final_resp or not final_resp.result_chain:
logger.warning("Agent Runner 未返回最终结果。")
return
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=self.runner_type,
provider_type=self.runner_type,
),
)

View File

@@ -0,0 +1,723 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
import traceback
from collections.abc import AsyncGenerator
from typing import Any
from mcp.types import CallToolResult
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
from ..stage import Stage
from ..utils import inject_kb_context
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
else:
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
event=run_context.context.event,
)
event = run_context.context.event
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=run_context.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
logger.debug(f"Found call in: {ty}")
is_override_call = True
break
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
awaitable = tool.handler
method_name = "decorator_handler"
elif is_override_call:
awaitable = tool.call
method_name = "call"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
method_name=method_name,
**tool_args,
)
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
res = await tool.call(run_context, **tool_args)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
llm_response,
)
async def on_tool_end(
self,
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
run_context.context.event.clear_result()
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
),
)
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
)
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
# 获取对话上下文
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
self,
event: AstrMessageEvent,
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin,
p_ctx=self.ctx,
req=req,
)
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(req.contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
req.contexts = req.contexts[index:]
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# fix messages
req.contexts = self.fix_messages(req.contexts)
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
)
req.func_tool = None
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
)
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use),
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)
async def _handle_webchat(
self,
event: AstrMessageEvent,
req: ProviderRequest,
prov: Provider,
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin,
req.conversation.cid,
)
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
latest_pair = messages[-2:]
if not latest_pair:
return
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
),
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
)
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
unified_msg_origin=event.unified_msg_origin,
title=title,
conversation_id=req.conversation.cid,
)
async def _save_to_history(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
):
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
return
if not llm_response.completion_text and not req.tool_calls_result:
logger.debug("LLM 响应为空,不保存记录。")
return
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=messages,
)
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages

View File

@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> AsyncGenerator[None, None]:
) -> None | AsyncGenerator[None, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
)

View File

@@ -1,12 +1,13 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.star.star_handler import StarHandlerMetadata
from ..context import PipelineContext
from ..stage import Stage, register_stage
from .method.agent_request import AgentRequestSubStage
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
@@ -16,12 +17,9 @@ class ProcessStage(Stage):
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
await self.llm_request_sub_stage.initialize(ctx)
# initialize agent sub stage
self.agent_sub_stage = AgentRequestSubStage()
await self.agent_sub_stage.initialize(ctx)
# initialize star request sub stage
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
@@ -41,7 +39,7 @@ class ProcessStage(Stage):
# Handler 的 LLM 请求
event.set_extra("provider_request", resp)
_t = False
async for _ in self.agent_sub_stage.process(event):
async for _ in self.llm_request_sub_stage.process(event):
_t = True
yield
if not _t:
@@ -62,5 +60,12 @@ class ProcessStage(Stage):
if (
event.get_result() and not event.get_result().is_stopped()
) or not event.get_result():
async for _ in self.agent_sub_stage.process(event):
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.llm_request_sub_stage.process(event):
yield

View File

@@ -1,64 +1,23 @@
from pydantic import Field
from pydantic.dataclasses import dataclass
from astrbot.api import logger, sp
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.star.context import Context
from astrbot.core.provider.entities import ProviderRequest
from ..context import PipelineContext
@dataclass
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
name: str = "astr_kb_search"
description: str = (
"Query the knowledge base for facts or relevant context. "
"Use this tool when the user's question requires factual information, "
"definitions, background knowledge, or previously indexed content. "
"Only send short keywords or a concise question as the query."
)
parameters: dict = Field(
default_factory=lambda: {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "A concise keyword query for the knowledge base.",
},
},
"required": ["query"],
}
)
async def call(
self, context: ContextWrapper[AstrAgentContext], **kwargs
) -> ToolExecResult:
query = kwargs.get("query", "")
if not query:
return "error: Query parameter is empty."
result = await retrieve_knowledge_base(
query=kwargs.get("query", ""),
umo=context.context.event.unified_msg_origin,
context=context.context.context,
)
if not result:
return "No relevant knowledge found."
return result
async def retrieve_knowledge_base(
query: str,
async def inject_kb_context(
umo: str,
context: Context,
) -> str | None:
p_ctx: PipelineContext,
req: ProviderRequest,
) -> None:
"""Inject knowledge base context into the provider request
Args:
umo: Unique message object (session ID)
p_ctx: Pipeline context
req: Provider request
"""
kb_mgr = context.kb_manager
config = context.get_config(umo=umo)
kb_mgr = p_ctx.plugin_manager.context.kb_manager
# 1. 优先读取会话级配置
session_config = await sp.session_get(umo, "kb_config", default={})
@@ -95,18 +54,18 @@ async def retrieve_knowledge_base(
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
else:
kb_names = config.get("kb_names", [])
top_k = config.get("kb_final_top_k", 5)
kb_names = p_ctx.astrbot_config.get("kb_names", [])
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
top_k_fusion = config.get("kb_fusion_top_k", 20)
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
if not kb_names:
return
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
kb_context = await kb_mgr.retrieve(
query=query,
query=req.prompt,
kb_names=kb_names,
top_k_fusion=top_k_fusion,
top_m_final=top_k,
@@ -119,7 +78,4 @@ async def retrieve_knowledge_base(
if formatted:
results = kb_context.get("results", [])
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
return formatted
KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool()
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"

View File

@@ -1,7 +1,6 @@
import asyncio
import math
import random
from collections.abc import AsyncGenerator
import astrbot.core.message.components as Comp
from astrbot.core import logger
@@ -10,6 +9,7 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
from ..context import PipelineContext, call_event_hook
from ..stage import Stage, register_stage
@@ -152,7 +152,7 @@ class RespondStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
) -> None:
result = event.get_result()
if result is None:
return
@@ -168,15 +168,12 @@ class RespondStage(Stage):
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
realtime_segmenting = (
self.config.get("provider_settings", {}).get(
"unsupported_streaming_strategy",
"realtime_segmenting",
)
== "realtime_segmenting"
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented",
False,
)
logger.info(f"应用流式输出({event.get_platform_id()})")
await event.send_streaming(result.async_stream, realtime_segmenting)
await event.send_streaming(result.async_stream, use_fallback)
return
if len(result.chain) > 0:
# 检查路径映射
@@ -220,20 +217,21 @@ class RespondStage(Stage):
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
)
return
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
else:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}

View File

@@ -161,21 +161,11 @@ class ResultDecorateStage(Stage):
# 不分段回复
new_chain.append(comp)
continue
try:
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
)
except re.error:
logger.error(
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
)
split_response = re.findall(
r".*?[。?!~…]+|.+$",
comp.text,
re.DOTALL | re.MULTILINE,
)
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
)
if not split_response:
new_chain.append(comp)
continue

View File

@@ -16,6 +16,3 @@ class PlatformMetadata:
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str | None = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
support_streaming_message: bool = True
"""平台是否支持真实流式传输"""

View File

@@ -14,7 +14,6 @@ def register_platform_adapter(
default_config_tmpl: dict | None = None,
adapter_display_name: str | None = None,
logo_path: str | None = None,
support_streaming_message: bool = True,
):
"""用于注册平台适配器的带参装饰器。
@@ -43,7 +42,6 @@ def register_platform_adapter(
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
support_streaming_message=support_streaming_message,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls

View File

@@ -29,7 +29,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
@register_platform_adapter(
"aiocqhttp",
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
support_streaming_message=False,
)
class AiocqhttpAdapter(Platform):
def __init__(
@@ -50,7 +49,6 @@ class AiocqhttpAdapter(Platform):
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
support_streaming_message=False,
)
self.bot = CQHttp(
@@ -109,7 +107,7 @@ class AiocqhttpAdapter(Platform):
)
await super().send_by_session(session, message_chain)
async def convert_message(self, event: Event) -> AstrBotMessage | None:
async def convert_message(self, event: Event) -> AstrBotMessage:
logger.debug(f"[aiocqhttp] RawMessage {event}")
if event["post_type"] == "message":
@@ -224,7 +222,7 @@ class AiocqhttpAdapter(Platform):
err = f"aiocqhttp: 无法识别的消息类型: {event.message!s},此条消息将被忽略。如果您在使用 go-cqhttp请将其配置文件中的 message.post-format 更改为 array。"
logger.critical(err)
try:
await self.bot.send(event, err)
self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return None

View File

@@ -37,9 +37,7 @@ class MyEventHandler(dingtalk_stream.EventHandler):
return AckMessage.STATUS_OK, "OK"
@register_platform_adapter(
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
)
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
class DingtalkPlatformAdapter(Platform):
def __init__(
self,
@@ -76,14 +74,6 @@ class DingtalkPlatformAdapter(Platform):
)
self.client_ = client # 用于 websockets 的 client
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
if not dingtalk_id:
return dingtalk_id
prefix = "$:LWCP_v1:$"
if dingtalk_id.startswith(prefix):
return dingtalk_id[len(prefix) :]
return dingtalk_id
async def send_by_session(
self,
session: MessageSesion,
@@ -96,7 +86,6 @@ class DingtalkPlatformAdapter(Platform):
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
support_streaming_message=False,
)
async def convert_msg(
@@ -113,10 +102,10 @@ class DingtalkPlatformAdapter(Platform):
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
user_id=self._id_to_sid(message.sender_id),
user_id=message.sender_id,
nickname=message.sender_nick,
)
abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.self_id = message.chatbot_user_id
abm.message_id = message.message_id
abm.raw_message = message
@@ -124,8 +113,8 @@ class DingtalkPlatformAdapter(Platform):
# 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
if message.at_users:
for user in message.at_users:
if id := self._id_to_sid(user.dingtalk_id):
abm.message.append(At(qq=id))
if user.dingtalk_id:
abm.message.append(At(qq=user.dingtalk_id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
@@ -250,7 +239,7 @@ class DingtalkPlatformAdapter(Platform):
async def terminate(self):
def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown")
raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")

View File

@@ -34,9 +34,7 @@ else:
# 注册平台适配器
@register_platform_adapter(
"discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
)
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
class DiscordPlatformAdapter(Platform):
def __init__(
self,
@@ -92,7 +90,7 @@ class DiscordPlatformAdapter(Platform):
)
message_obj.self_id = self.client_self_id
message_obj.session_id = session.session_id
message_obj.message = message_chain.chain
message_obj.message = message_chain
# 创建临时事件对象来发送消息
temp_event = DiscordPlatformEvent(
@@ -113,7 +111,6 @@ class DiscordPlatformAdapter(Platform):
"Discord 适配器",
id=self.config.get("id"),
default_config_tmpl=self.config,
support_streaming_message=False,
)
@override

View File

@@ -1,7 +1,6 @@
import asyncio
import base64
import binascii
from collections.abc import AsyncGenerator
import sys
from io import BytesIO
from pathlib import Path
@@ -21,6 +20,11 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata
from .client import DiscordBotClient
from .components import DiscordEmbed, DiscordView
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# 自定义Discord视图组件兼容旧版本
class DiscordViewComponent(BaseMessageComponent):
@@ -44,6 +48,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.client = client
self.interaction_followup_webhook = interaction_followup_webhook
@override
async def send(self, message: MessageChain):
"""发送消息到Discord平台"""
# 解析消息链为 Discord 所需的对象
@@ -92,21 +97,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None:
"""获取当前事件对应的频道对象"""
try:
@@ -193,7 +183,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
BytesIO(img_bytes),
filename=filename or "image.png",
)
except (ValueError, TypeError, binascii.Error):
except (ValueError, TypeError, base64.binascii.Error):
logger.debug(
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}",
)

View File

@@ -23,9 +23,7 @@ from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent
@register_platform_adapter(
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
)
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
class LarkPlatformAdapter(Platform):
def __init__(
self,
@@ -117,7 +115,6 @@ class LarkPlatformAdapter(Platform):
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):

View File

@@ -45,9 +45,7 @@ MAX_FILE_UPLOAD_COUNT = 16
DEFAULT_UPLOAD_CONCURRENCY = 3
@register_platform_adapter(
"misskey", "Misskey 平台适配器", support_streaming_message=False
)
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self,
@@ -122,7 +120,6 @@ class MisskeyPlatformAdapter(Platform):
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
support_streaming_message=False,
)
async def run(self):

View File

@@ -29,7 +29,8 @@ from astrbot.core.platform.astr_message_event import MessageSession
@register_platform_adapter(
"satori", "Satori 协议适配器", support_streaming_message=False
"satori",
"Satori 协议适配器",
)
class SatoriPlatformAdapter(Platform):
def __init__(
@@ -59,7 +60,6 @@ class SatoriPlatformAdapter(Platform):
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
support_streaming_message=False,
)
self.ws: ClientConnection | None = None

View File

@@ -30,7 +30,6 @@ from .slack_event import SlackMessageEvent
@register_platform_adapter(
"slack",
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
support_streaming_message=False,
)
class SlackAdapter(Platform):
def __init__(
@@ -69,7 +68,6 @@ class SlackAdapter(Platform):
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
support_streaming_message=False,
)
# 初始化 Slack Web Client
@@ -84,7 +82,7 @@ class SlackAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
):
blocks, text = await SlackMessageEvent._parse_slack_blocks(
blocks, text = SlackMessageEvent._parse_slack_blocks(
message_chain=message_chain,
web_client=self.web_client,
)

View File

@@ -5,6 +5,8 @@ import uuid
from collections.abc import Awaitable, Callable
from typing import Any
from astrbot_api.abc import IAstrbotPaths
from astrbot import logger
from astrbot.core.message.components import Image, Plain, Record
from astrbot.core.message.message_event_result import MessageChain
@@ -16,12 +18,13 @@ from astrbot.core.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_sdk import sync_base_container
from ...register import register_platform_adapter
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
class QueueListener:
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
@@ -79,7 +82,7 @@ class WebChatAdapter(Platform):
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
self.imgs_dir = str(AstrbotPaths.astrbot_root / "webchat" / "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata(
@@ -163,9 +166,6 @@ class WebChatAdapter(Platform):
_, _, payload = message.raw_message # type: ignore
message_event.set_extra("selected_provider", payload.get("selected_provider"))
message_event.set_extra("selected_model", payload.get("selected_model"))
message_event.set_extra(
"enable_streaming", payload.get("enable_streaming", True)
)
self.commit_event(message_event)

View File

@@ -109,7 +109,6 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
reasoning_content = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
@@ -125,22 +124,16 @@ class WebChatMessageEvent(AstrMessageEvent):
)
final_data = ""
continue
r = await WebChatMessageEvent._send(
final_data += await WebChatMessageEvent._send(
chain,
session_id=self.session_id,
streaming=True,
)
if chain.type == "reasoning":
reasoning_content += chain.get_plain_text()
else:
final_data += r
await web_chat_back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"reasoning": reasoning_content,
"streaming": True,
"cid": cid,
},

View File

@@ -8,10 +8,14 @@ import traceback
import aiohttp
import anyio
import websockets
from astrbot_api.abc import IAstrbotPaths
from astrbot import logger
from astrbot.api.message_components import At, Image, Plain, Record
from astrbot.api.platform import Platform, PlatformMetadata
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.astrbot_message import (
@@ -19,7 +23,6 @@ from astrbot.core.platform.astrbot_message import (
MessageMember,
MessageType,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
from .wechatpadpro_message_event import WeChatPadProMessageEvent
@@ -32,9 +35,7 @@ except ImportError as e:
)
@register_platform_adapter(
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
)
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
class WeChatPadProAdapter(Platform):
def __init__(
self,
@@ -53,7 +54,6 @@ class WeChatPadProAdapter(Platform):
name="wechatpadpro",
description="WeChatPadPro 消息平台适配器",
id=self.config.get("id", "wechatpadpro"),
support_streaming_message=False,
)
# 保存配置信息
@@ -71,9 +71,8 @@ class WeChatPadProAdapter(Platform):
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
self.credentials_file = str(
AstrbotPaths.astrbot_root / "wechatpadpro_credentials.json"
) # 持久化文件路径
self.ws_handle_task = None
@@ -158,8 +157,8 @@ class WeChatPadProAdapter(Platform):
}
try:
# 确保数据目录存在
data_dir = os.path.dirname(self.credentials_file)
os.makedirs(data_dir, exist_ok=True)
config_dir = AstrbotPaths.astrbot_root / "config"
config_dir.mkdir(parents=True, exist_ok=True)
with open(self.credentials_file, "w") as f:
json.dump(credentials, f)
except Exception as e:
@@ -790,10 +789,10 @@ class WeChatPadProAdapter(Platform):
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir,
f"wechatpadpro_voice_{abm.message_id}.silk",
file_path = str(
AstrbotPaths.astrbot_root
/ "temp"
/ f"wechatpadpro_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:

View File

@@ -1,7 +1,6 @@
import asyncio
import base64
import io
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
import aiohttp
@@ -51,21 +50,6 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
await self._send_voice(session, comp)
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
b64 = await comp.convert_to_base64()
raw = self._validate_base64(b64)

View File

@@ -110,7 +110,7 @@ class WecomServer:
await self.shutdown_event.wait()
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
@register_platform_adapter("wecom", "wecom 适配器")
class WecomPlatformAdapter(Platform):
def __init__(
self,
@@ -196,7 +196,6 @@ class WecomPlatformAdapter(Platform):
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
support_streaming_message=False,
)
@override

View File

@@ -10,7 +10,7 @@ import base64
import hashlib
import json
import logging
import secrets
import random
import socket
import struct
import time
@@ -139,12 +139,6 @@ class PKCS7Encoder:
class Prpcrypt:
"""提供接收和推送给企业微信消息的加解密接口"""
# 16位随机字符串的范围常量
# randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999]两端都包含即包含0和8999999999999999
# 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999]两端都包含即16位数字
MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位)
RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位)
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
@@ -213,9 +207,7 @@ class Prpcrypt:
"""随机生成16位字符串
@return: 16位字符串
"""
return str(
secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE
).encode()
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizJsonMsgCrypt:

View File

@@ -30,7 +30,7 @@ from .wecomai_api import (
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_queue_mgr import WecomAIQueueMgr, wecomai_queue_mgr
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import (
WecomAIBotConstants,
@@ -144,12 +144,9 @@ class WecomAIBotAdapter(Platform):
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
# 队列管理器
self.queue_mgr = WecomAIQueueMgr()
# 队列监听器
self.queue_listener = WecomAIQueueListener(
self.queue_mgr,
wecomai_queue_mgr,
self._handle_queued_message,
)
@@ -192,7 +189,7 @@ class WecomAIBotAdapter(Platform):
stream_id,
session_id,
)
self.queue_mgr.set_pending_response(stream_id, callback_params)
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id,
@@ -210,7 +207,7 @@ class WecomAIBotAdapter(Platform):
elif msgtype == "stream":
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not self.queue_mgr.has_back_queue(stream_id):
if not wecomai_queue_mgr.has_back_queue(stream_id):
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
# 返回结束标志,告诉微信服务器流已结束
@@ -225,7 +222,7 @@ class WecomAIBotAdapter(Platform):
callback_params["timestamp"],
)
return resp
queue = self.queue_mgr.get_or_create_back_queue(stream_id)
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if queue.empty():
logger.debug(
f"No new messages in back queue for stream_id: {stream_id}",
@@ -245,9 +242,10 @@ class WecomAIBotAdapter(Platform):
elif msg["type"] == "end":
# stream end
finish = True
self.queue_mgr.remove_queues(stream_id)
wecomai_queue_mgr.remove_queues(stream_id)
break
else:
pass
logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}",
)
@@ -315,8 +313,8 @@ class WecomAIBotAdapter(Platform):
session_id: str,
):
"""将消息放入队列进行异步处理"""
input_queue = self.queue_mgr.get_or_create_queue(stream_id)
_ = self.queue_mgr.get_or_create_back_queue(stream_id)
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
message_payload = {
"message_data": message_data,
"callback_params": callback_params,
@@ -455,7 +453,6 @@ class WecomAIBotAdapter(Platform):
platform_meta=self.meta(),
session_id=message.session_id,
api_client=self.api_client,
queue_mgr=self.queue_mgr,
)
self.commit_event(message_event)

View File

@@ -8,7 +8,7 @@ from astrbot.api.message_components import (
)
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_queue_mgr import wecomai_queue_mgr
class WecomAIBotMessageEvent(AstrMessageEvent):
@@ -21,7 +21,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
queue_mgr: WecomAIQueueMgr,
):
"""初始化消息事件
@@ -35,16 +34,14 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"""
super().__init__(message_str, message_obj, platform_meta, session_id)
self.api_client = api_client
self.queue_mgr = queue_mgr
@staticmethod
async def _send(
message_chain: MessageChain,
stream_id: str,
queue_mgr: WecomAIQueueMgr,
streaming: bool = False,
):
back_queue = queue_mgr.get_or_create_back_queue(stream_id)
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if not message_chain:
await back_queue.put(
@@ -97,7 +94,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await WecomAIBotMessageEvent._send(message, stream_id)
await super().send(message)
async def send_streaming(self, generator, use_fallback=False):
@@ -108,7 +105,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
increment_plain = ""
@@ -137,7 +134,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
queue_mgr=self.queue_mgr,
streaming=True,
)

View File

@@ -151,3 +151,7 @@ class WecomAIQueueMgr:
"output_queues": len(self.back_queues),
"pending_responses": len(self.pending_responses),
}
# 全局队列管理器实例
wecomai_queue_mgr = WecomAIQueueMgr()

View File

@@ -5,7 +5,7 @@
import asyncio
import base64
import hashlib
import secrets
import random
import string
from typing import Any
@@ -53,7 +53,7 @@ def generate_random_string(length: int = 10) -> str:
"""
letters = string.ascii_letters + string.digits
return "".join(secrets.choice(letters) for _ in range(length))
return "".join(random.choice(letters) for _ in range(length))
def calculate_image_md5(image_data: bytes) -> str:

View File

@@ -113,9 +113,7 @@ class WecomServer:
await self.shutdown_event.wait()
@register_platform_adapter(
"weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
)
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
class WeixinOfficialAccountPlatformAdapter(Platform):
def __init__(
self,
@@ -197,7 +195,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
support_streaming_message=False,
)
@override

View File

@@ -1,4 +1,4 @@
from .entities import ProviderMetaData
from .provider import Provider, STTProvider
from .provider import Personality, Provider, STTProvider
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]

View File

@@ -30,31 +30,18 @@ class ProviderType(enum.Enum):
@dataclass
class ProviderMeta:
"""The basic metadata of a provider instance."""
id: str
"""the unique id of the provider instance that user configured"""
model: str | None
"""the model name of the provider instance currently used"""
class ProviderMetaData:
type: str
"""the name of the provider adapter, such as openai, ollama"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
"""the capability type of the provider adapter"""
@dataclass
class ProviderMetaData(ProviderMeta):
"""The metadata of a provider adapter for registration."""
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""the short description of the provider adapter"""
"""提供商适配器描述"""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Any = None
"""the class type of the provider adapter"""
default_config_tmpl: dict | None = None
"""the default configuration template of the provider adapter"""
"""平台的默认配置模板"""
provider_display_name: str | None = None
"""the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
@@ -73,20 +60,12 @@ class ToolCallsResult:
]
return ret
def to_openai_messages_model(
self,
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
return [
self.tool_calls_info,
*self.tool_calls_result,
]
@dataclass
class ProviderRequest:
prompt: str | None = None
prompt: str
"""提示词"""
session_id: str | None = ""
session_id: str = ""
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
@@ -202,30 +181,25 @@ class ProviderRequest:
@dataclass
class LLMResponse:
role: str
"""The role of the message, e.g., assistant, tool, err"""
"""角色, assistant, tool, err"""
result_chain: MessageChain | None = None
"""A chain of message components representing the text completion from LLM."""
"""返回的消息链"""
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
"""Tool call arguments."""
"""工具调用参数"""
tools_call_name: list[str] = field(default_factory=list)
"""Tool call names."""
"""工具调用名称"""
tools_call_ids: list[str] = field(default_factory=list)
"""Tool call IDs."""
tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Tool call extra content. tool_call_id -> extra_content dict"""
reasoning_content: str = ""
"""The reasoning content extracted from the LLM, if any."""
"""工具调用 ID"""
raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
) = None
"""The raw completion response from the LLM provider."""
_new_record: dict[str, Any] | None = None
_completion_text: str = ""
"""The plain text of the completion."""
is_chunk: bool = False
"""Indicates if the response is a chunked response."""
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
@@ -235,11 +209,11 @@ class LLMResponse:
tools_call_args: list[dict[str, Any]] | None = None,
tools_call_name: list[str] | None = None,
tools_call_ids: list[str] | None = None,
tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
raw_completion: ChatCompletion
| GenerateContentResponse
| AnthropicMessage
| None = None,
_new_record: dict[str, Any] | None = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -259,8 +233,6 @@ class LLMResponse:
tools_call_name = []
if tools_call_ids is None:
tools_call_ids = []
if tools_call_extra_content is None:
tools_call_extra_content = {}
self.role = role
self.completion_text = completion_text
@@ -268,8 +240,8 @@ class LLMResponse:
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.tools_call_extra_content = tools_call_extra_content
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property
@@ -294,19 +266,16 @@ class LLMResponse:
"""Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
payload = {
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
ret.append(
{
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
},
"type": "function",
}
if self.tools_call_extra_content.get(self.tools_call_ids[idx]):
payload["extra_content"] = self.tools_call_extra_content[
self.tools_call_ids[idx]
]
ret.append(payload)
)
return ret
def to_openai_to_calls_model(self) -> list[ToolCall]:
@@ -320,10 +289,6 @@ class LLMResponse:
name=self.tools_call_name[idx],
arguments=json.dumps(tool_call_arg),
),
# the extra_content will not serialize if it's None when calling ToolCall.model_dump()
extra_content=self.tools_call_extra_content.get(
self.tools_call_ids[idx]
),
),
)
return ret

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
@@ -25,16 +24,7 @@ SUPPORTED_TYPES = [
"boolean",
] # json schema 支持的数据类型
PY_TO_JSON_TYPE = {
"int": "number",
"float": "number",
"bool": "boolean",
"str": "string",
"dict": "object",
"list": "array",
"tuple": "array",
"set": "array",
}
# alias
FuncTool = FunctionTool
@@ -116,7 +106,7 @@ class FunctionToolManager:
def spec_to_func(
self,
name: str,
func_args: list[dict],
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
@@ -125,9 +115,10 @@ class FunctionToolManager:
"properties": {},
}
for param in func_args:
p = copy.deepcopy(param)
p.pop("name", None)
params["properties"][param["name"]] = p
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
return FuncTool(
name=name,
parameters=params,
@@ -280,22 +271,19 @@ class FunctionToolManager:
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
client = self.mcp_client_dict[name]
try:
# 关闭MCP连接
await client.cleanup()
await self.mcp_client_dict[name].cleanup()
self.mcp_client_dict.pop(name)
except Exception as e:
logger.error(f"清空 MCP 客户端资源 {name}: {e}")
finally:
# Remove client from dict after cleanup attempt (successful or not)
self.mcp_client_dict.pop(name, None)
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:

View File

@@ -1,7 +1,7 @@
import asyncio
import traceback
from astrbot.core import astrbot_config, logger, sp
from astrbot.core import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
@@ -24,7 +24,6 @@ class ProviderManager:
db_helper: BaseDatabase,
persona_mgr: PersonaManager,
):
self.reload_lock = asyncio.Lock()
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
@@ -227,9 +226,6 @@ class ProviderManager:
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return
if provider_config.get("provider_type", "") == "agent_runner":
return
logger.info(
@@ -245,12 +241,18 @@ class ProviderManager:
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
@@ -327,10 +329,6 @@ class ProviderManager:
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
@@ -356,8 +354,6 @@ class ProviderManager:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = cls_type(provider_config, self.provider_settings)
@@ -398,6 +394,7 @@ class ProviderManager:
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
)
if getattr(inst, "initialize", None):
@@ -436,46 +433,40 @@ class ProviderManager:
)
async def reload(self, provider_config: dict):
async with self.reload_lock:
await self.terminate_provider(provider_config["id"])
if provider_config["enable"]:
await self.load_provider(provider_config)
await self.terminate_provider(provider_config["id"])
if provider_config["enable"]:
await self.load_provider(provider_config)
# 和配置文件保持同步
self.providers_config = astrbot_config["provider"]
config_ids = [provider["id"] for provider in self.providers_config]
logger.info(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.providers_config]
logger.debug(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
)
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
):
self.curr_stt_provider_inst = self.stt_provider_insts[0]
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
):
self.curr_tts_provider_inst = self.tts_provider_insts[0]
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
)
def get_insts(self):
return self.provider_insts

View File

@@ -1,18 +1,28 @@
import abc
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Personality
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMeta,
ProviderType,
RerankResult,
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
@dataclass
class ProviderMeta:
id: str
model: str
type: str
provider_type: ProviderType
class AbstractProvider(abc.ABC):
"""Provider Abstract Class"""
@@ -33,15 +43,15 @@ class AbstractProvider(abc.ABC):
"""Get the provider metadata"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
if not meta_data:
raise ValueError(f"Provider type {provider_type_name} not registered")
meta = ProviderMeta(
id=self.provider_config.get("id", "default"),
provider_type = meta_data.provider_type if meta_data else None
if provider_type is None:
raise ValueError(f"Cannot find provider type: {provider_type_name}")
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
type=provider_type_name,
provider_type=meta_data.provider_type,
provider_type=provider_type,
)
return meta
class Provider(AbstractProvider):
@@ -51,10 +61,15 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
self.curr_personality = default_persona
"""维护了当前的使用的 persona即人格。可能为 None"""
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError

View File

@@ -36,8 +36,6 @@ def register_provider_adapter(
default_config_tmpl["id"] = provider_type_name
pm = ProviderMetaData(
id="default", # will be replaced when instantiated
model=None,
type=provider_type_name,
desc=desc,
provider_type=provider_type,

View File

@@ -25,10 +25,12 @@ class ProviderAnthropic(Provider):
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.chosen_api_key: str = ""
@@ -290,7 +292,7 @@ class ProviderAnthropic(Provider):
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response

View File

@@ -1,8 +1,8 @@
import asyncio
import hashlib
import json
import random
import re
import secrets
import time
import uuid
from pathlib import Path
@@ -54,9 +54,7 @@ class OTTSProvider:
async def _generate_signature(self) -> str:
await self._sync_time()
timestamp = int(time.time()) + self.time_offset
nonce = "".join(
secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
)
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"

View File

@@ -1,236 +0,0 @@
import os
import aiohttp
from astrbot import logger
from ..entities import ProviderType, RerankResult
from ..provider import RerankProvider
from ..register import register_provider_adapter
class BailianRerankError(Exception):
"""百炼重排序服务异常基类"""
pass
class BailianAPIError(BailianRerankError):
"""百炼API返回错误"""
pass
class BailianNetworkError(BailianRerankError):
"""百炼网络请求错误"""
pass
@register_provider_adapter(
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
"""阿里云百炼文本重排序适配器."""
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
# API配置
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
"DASHSCOPE_API_KEY", ""
)
if not self.api_key:
raise ValueError("阿里云百炼 API Key 不能为空。")
self.model = provider_config.get("rerank_model", "qwen3-rerank")
self.timeout = provider_config.get("timeout", 30)
self.return_documents = provider_config.get("return_documents", False)
self.instruct = provider_config.get("instruct", "")
self.base_url = provider_config.get(
"rerank_api_base",
"https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
)
# 设置HTTP客户端
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
self.client = aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
)
# 设置模型名称
self.set_model(self.model)
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
def _build_payload(
self, query: str, documents: list[str], top_n: int | None
) -> dict:
"""构建请求载荷
Args:
query: 查询文本
documents: 文档列表
top_n: 返回前N个结果如果为None则返回所有结果
Returns:
请求载荷字典
"""
base = {"model": self.model, "input": {"query": query, "documents": documents}}
params = {
k: v
for k, v in [
("top_n", top_n if top_n is not None and top_n > 0 else None),
("return_documents", True if self.return_documents else None),
(
"instruct",
self.instruct
if self.instruct and self.model == "qwen3-rerank"
else None,
),
]
if v is not None
}
if params:
base["parameters"] = params
return base
def _parse_results(self, data: dict) -> list[RerankResult]:
"""解析API响应结果
Args:
data: API响应数据
Returns:
重排序结果列表
Raises:
BailianAPIError: API返回错误
KeyError: 结果缺少必要字段
"""
# 检查响应状态
if data.get("code", "200") != "200":
raise BailianAPIError(
f"百炼 API 错误: {data.get('code')} {data.get('message', '')}"
)
results = data.get("output", {}).get("results", [])
if not results:
logger.warning(f"百炼 Rerank 返回空结果: {data}")
return []
# 转换为RerankResult对象使用.get()避免KeyError
rerank_results = []
for idx, result in enumerate(results):
try:
index = result.get("index", idx)
relevance_score = result.get("relevance_score", 0.0)
if relevance_score is None:
logger.warning(f"结果 {idx} 缺少 relevance_score使用默认值 0.0")
relevance_score = 0.0
rerank_result = RerankResult(
index=index, relevance_score=relevance_score
)
rerank_results.append(rerank_result)
except Exception as e:
logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
continue
return rerank_results
def _log_usage(self, data: dict) -> None:
"""记录使用量信息
Args:
data: API响应数据
"""
tokens = data.get("usage", {}).get("total_tokens", 0)
if tokens > 0:
logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
async def rerank(
self,
query: str,
documents: list[str],
top_n: int | None = None,
) -> list[RerankResult]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 待排序的文档列表
top_n: 返回前N个结果如果为None则使用配置中的默认值
Returns:
重排序结果列表
"""
if not documents:
logger.warning("文档列表为空,返回空结果")
return []
if not query.strip():
logger.warning("查询文本为空,返回空结果")
return []
# 检查限制
if len(documents) > 500:
logger.warning(
f"文档数量({len(documents)})超过限制(500)将截断前500个文档"
)
documents = documents[:500]
try:
# 构建请求载荷如果top_n为None则返回所有重排序结果
payload = self._build_payload(query, documents, top_n)
logger.debug(
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
)
# 发送请求
async with self.client.post(self.base_url, json=payload) as response:
response.raise_for_status()
response_data = await response.json()
# 解析结果并记录使用量
results = self._parse_results(response_data)
self._log_usage(response_data)
logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
return results
except aiohttp.ClientError as e:
error_msg = f"网络请求失败: {e}"
logger.error(f"百炼 Rerank 网络请求失败: {e}")
raise BailianNetworkError(error_msg) from e
except BailianRerankError:
raise
except Exception as e:
error_msg = f"重排序失败: {e}"
logger.error(f"百炼 Rerank 处理失败: {e}")
raise BailianRerankError(error_msg) from e
async def terminate(self) -> None:
"""关闭HTTP客户端会话."""
if self.client:
logger.info("关闭 百炼 Rerank 客户端会话")
try:
await self.client.close()
except Exception as e:
logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
finally:
self.client = None

View File

@@ -0,0 +1,652 @@
import base64
import hashlib
import json
import os
from collections.abc import AsyncGenerator
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: dict[str, str] = {}
self.file_id_cache: dict[str, dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
elif data.startswith(("http://", "https://")):
# URL图片使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self,
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
raise Exception(f"处理图片失败: {e!s}")
async def _process_context_images(
self,
content: str | list,
session_id: str,
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data,
is_base64=image_data.startswith("data:image/"),
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id},
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data,
session_id,
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}...",
)
continue
processed_content.append(
{"type": "image", "file_id": file_id},
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {e!s}")
if isinstance(content, str):
return content
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{
"role": "system",
"content": system_prompt,
"content_type": "text",
},
)
contexts = self._ensure_message_to_dicts(contexts)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content,
user_id,
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
},
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
},
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url,
user_id,
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url,
is_base64=True,
)
file_id = await self._upload_file(
image_data,
user_id,
cache_key,
)
# 本地文件
elif os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data,
user_id,
cache_key,
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
},
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {e!s}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
},
)
# 纯文本
elif prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
},
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {e!s}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {e!s}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {e!s}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self,
session_id: str,
page: int = 1,
page_size: int = 10,
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {e!s}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()

View File

@@ -0,0 +1,209 @@
import asyncio
import functools
import re
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from .. import Personality, Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
default_persona: Personality | None = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""判断是否有 RAG 选项
Returns:
bool: 是否有 RAG 选项
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
if contexts is None:
contexts = []
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
contexts_no_img = await self._remove_image_from_context(contexts)
context_query = [*contexts_no_img, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"api_key": self.api_key,
"messages": context_query,
"biz_params": payload_vars or None,
}
partial = functools.partial(
Application.call,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
# 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
partial = functools.partial(
Application.call,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
assert isinstance(response, ApplicationResponse)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
)
return LLMResponse(
role="err",
result_chain=MessageChain().message(
f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
),
)
output_text = response.output.get("text", "") or ""
# RAG 引用脚标格式化
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_parts = []
for ref in response.output.get("doc_references", []) or []:
ref_title = (
ref.get("title", "")
if ref.get("title")
else ref.get("doc_name", "")
)
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
ref_str = "".join(ref_parts)
output_text += f"\n\n回答来源:\n{ref_str}"
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(output_text)
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def forget(self, session_id):
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
async def terminate(self):
pass

View File

@@ -15,7 +15,11 @@ except (
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot_api.abc import IAstrbotPaths
from astrbot_sdk import sync_base_container
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -45,7 +49,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):

View File

@@ -0,0 +1,287 @@
import os
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_file, download_image_by_url
from .. import Provider
from ..entities import LLMResponse
from ..register import register_provider_adapter
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
self.model_name = "dify"
self.workflow_output_key = provider_config.get(
"dify_workflow_output_key",
"astrbot_wf_output",
)
self.dify_query_input_key = provider_config.get(
"dify_query_input_key",
"astrbot_text_query",
)
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
if not self.workflow_output_key:
self.workflow_output_key = "astrbot_wf_output"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.conversation_ids = {}
"""记录当前 session id 的对话 ID"""
self.api_client = DifyAPIClient(self.api_key, api_base)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
result = ""
session_id = session_id or kwargs.get("user") or "unknown" # 1734
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
for image_url in image_urls:
image_path = (
await download_image_by_url(image_url)
if image_url.startswith("http")
else image_url
)
file_response = await self.api_client.file_upload(
image_path,
user=session_id,
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
},
)
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
try:
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if (
chunk["event"] == "message"
or chunk["event"] == "agent_message"
):
result += chunk["answer"]
if not conversation_id:
self.conversation_ids[session_id] = chunk[
"conversation_id"
]
conversation_id = chunk["conversation_id"]
elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
raise Exception(
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
)
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout,
):
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。",
)
case "node_finished":
logger.debug(
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。",
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束",
)
logger.debug(f"Dify 工作流结果:{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}",
)
raise Exception(
f"Dify 工作流出现错误:{chunk['data']['error']}",
)
if (
self.workflow_output_key
not in chunk["data"]["outputs"]
):
raise Exception(
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}",
)
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
except Exception as e:
logger.error(f"Dify 请求失败:{e!s}")
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}")
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
chain = await self.parse_dify_result(result)
return LLMResponse(role="assistant", result_chain=chain)
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains = []
if isinstance(output, str):
# 纯文本输出
chains.append(Comp.Plain(output))
elif isinstance(output, list):
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
for item in output:
# handle Array[File]
if (
not isinstance(item, dict)
or item.get("dify_model_identity", "") != "__dify__file__"
):
chains.append(Comp.Plain(str(output)))
break
else:
chains.append(Comp.Plain(str(output)))
# scan file
files = chunk["data"].get("files", [])
for item in files:
comp = await parse_file(item)
chains.append(comp)
return MessageChain(chain=chains)
async def forget(self, session_id):
self.conversation_ids[session_id] = ""
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("Dify 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 Dify 的历史消息记录。")
async def terminate(self):
await self.api_client.close()

Some files were not shown because too many files have changed in this diff Show More