Compare commits
300 Commits
v3.5.17
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6b6eef8c4 | ||
|
|
50cf263076 | ||
|
|
2554548088 | ||
|
|
aa4a2d10e2 | ||
|
|
02a9769b35 | ||
|
|
7640f11bfc | ||
|
|
9fa44dbcfa | ||
|
|
2cae941bae | ||
|
|
bc0784f41d | ||
|
|
c57d75e01a | ||
|
|
73edeae013 | ||
|
|
7d46314dc8 | ||
|
|
d5a53a89eb | ||
|
|
a85bc510dd | ||
|
|
2beea7d218 | ||
|
|
a93cd3dd5f | ||
|
|
db4d02c2e2 | ||
|
|
fd7811402b | ||
|
|
eb0325e627 | ||
|
|
8b4b04ec09 | ||
|
|
9f32c9280f | ||
|
|
4fcd09cfa8 | ||
|
|
7a8d65d37d | ||
|
|
23129a9ba2 | ||
|
|
7f791e730b | ||
|
|
f7e296b349 | ||
|
|
712d4acaaa | ||
|
|
74a5c01f21 | ||
|
|
3ba8724d77 | ||
|
|
6313a7d8a9 | ||
|
|
432a3f520c | ||
|
|
191b3e42d4 | ||
|
|
a27f05fcb4 | ||
|
|
2f33e0b873 | ||
|
|
f0359467f1 | ||
|
|
d1db8cf2c8 | ||
|
|
b1985ed2ce | ||
|
|
140ddc70e6 | ||
|
|
d7fd616470 | ||
|
|
3ccbef141e | ||
|
|
e92fbb0443 | ||
|
|
bd270aed68 | ||
|
|
28d7864393 | ||
|
|
b5d8173ee3 | ||
|
|
17d62a9af7 | ||
|
|
d89fb863ed | ||
|
|
a21ad77820 | ||
|
|
f86c8e8cab | ||
|
|
cb12cbdd3d | ||
|
|
6661fa996c | ||
|
|
c19bca798b | ||
|
|
8f98b411db | ||
|
|
a8aa03847e | ||
|
|
1bfd747cc6 | ||
|
|
ae06d945a7 | ||
|
|
9f41d5f34d | ||
|
|
ef61c52908 | ||
|
|
d8842ef274 | ||
|
|
c88fdaf353 | ||
|
|
af295da871 | ||
|
|
083235a2fe | ||
|
|
2a3a5f7eb2 | ||
|
|
77c48f280f | ||
|
|
0ee1eb2f9f | ||
|
|
c2b20365bb | ||
|
|
cfdc7e4452 | ||
|
|
2363f61aa9 | ||
|
|
557ac6f9fa | ||
|
|
a49b871cf9 | ||
|
|
a0d6b3efba | ||
|
|
6cabf07bc0 | ||
|
|
a15444ee8c | ||
|
|
ceb5f5669e | ||
|
|
25b75e05e4 | ||
|
|
4d214bb5c1 | ||
|
|
7cbaed8c6c | ||
|
|
2915fdf665 | ||
|
|
a66c385b08 | ||
|
|
4dace7c5d8 | ||
|
|
8ebf087dbf | ||
|
|
2fa8bda5bb | ||
|
|
a5ae833945 | ||
|
|
d21d42b312 | ||
|
|
78575f0f0a | ||
|
|
8ccd292d16 | ||
|
|
2534f59398 | ||
|
|
5c60dbe2b1 | ||
|
|
c99ecde15f | ||
|
|
219f3403d9 | ||
|
|
00f417bad6 | ||
|
|
81649f053b | ||
|
|
e5bde50f2d | ||
|
|
0321e00b0d | ||
|
|
09528e3292 | ||
|
|
e7412a9cbf | ||
|
|
01efe5f869 | ||
|
|
28a178a55c | ||
|
|
88f130014c | ||
|
|
af258c590c | ||
|
|
b0eb5733be | ||
|
|
fe35bfba37 | ||
|
|
7cfbc4ab8f | ||
|
|
7a9d4f0abd | ||
|
|
6f6a5b565c | ||
|
|
e57deb873c | ||
|
|
0f692b1608 | ||
|
|
8c03e79f99 | ||
|
|
71290f0929 | ||
|
|
22364ef7de | ||
|
|
2cc1eb1abc | ||
|
|
90dbcbb4e2 | ||
|
|
66503d58be | ||
|
|
8e10f0ce2b | ||
|
|
f51f510f2e | ||
|
|
c44f085b47 | ||
|
|
a35f36eeaf | ||
|
|
14564c392a | ||
|
|
76e05ea749 | ||
|
|
ab599dceed | ||
|
|
4c37604445 | ||
|
|
bb74018d19 | ||
|
|
575289e5bc | ||
|
|
e89da2a7b4 | ||
|
|
bd34959f68 | ||
|
|
622dcf8fd5 | ||
|
|
9e315739b7 | ||
|
|
7b01adc5df | ||
|
|
432fc47443 | ||
|
|
d8fba44c5e | ||
|
|
e29d3d8c01 | ||
|
|
e678413214 | ||
|
|
eaa9d9d087 | ||
|
|
9e3cc076b7 | ||
|
|
3bb01fa52c | ||
|
|
008e49d144 | ||
|
|
4e275384b0 | ||
|
|
63ec99f67a | ||
|
|
14a8bb57df | ||
|
|
7512bfc710 | ||
|
|
3c3b6dadc3 | ||
|
|
cd722a0e39 | ||
|
|
a1b5d0a100 | ||
|
|
69d3ae709c | ||
|
|
67ef993d61 | ||
|
|
20f49890ad | ||
|
|
3e4917f0a1 | ||
|
|
99ee75aec6 | ||
|
|
1674653a42 | ||
|
|
d2f7e55bf5 | ||
|
|
9f31df7f3a | ||
|
|
b8c1b53d67 | ||
|
|
2495837791 | ||
|
|
b6562e3c47 | ||
|
|
c57da046ee | ||
|
|
ff63134c14 | ||
|
|
3f5210c587 | ||
|
|
3df5e7b9b9 | ||
|
|
225db66738 | ||
|
|
383ebb8f57 | ||
|
|
e1bed60f1f | ||
|
|
edbb856023 | ||
|
|
98d3ab646f | ||
|
|
81be556f1b | ||
|
|
f45a085469 | ||
|
|
210cc58cc3 | ||
|
|
1063b11ef6 | ||
|
|
a4e999c47f | ||
|
|
543e01c301 | ||
|
|
14e0aa3ec5 | ||
|
|
1a8a171f8b | ||
|
|
f1954f9a43 | ||
|
|
441b148501 | ||
|
|
bd0f30b81c | ||
|
|
ad14e9bf40 | ||
|
|
6f71301aaf | ||
|
|
5f0d601baa | ||
|
|
f234a5bcc2 | ||
|
|
ab677ea100 | ||
|
|
f3ad53e949 | ||
|
|
d324cfa84d | ||
|
|
dd4319d72a | ||
|
|
1f2de3d3d8 | ||
|
|
72702beb0b | ||
|
|
adb0cbc5dd | ||
|
|
6a503b82c3 | ||
|
|
28a87351f1 | ||
|
|
bcc97378b0 | ||
|
|
eb8a138713 | ||
|
|
dcd7dcbbdf | ||
|
|
1538759ba7 | ||
|
|
30e8ea7fd8 | ||
|
|
879b7b582c | ||
|
|
8ba4236402 | ||
|
|
5eef8fa9b9 | ||
|
|
d03d035437 | ||
|
|
68e8e1f70b | ||
|
|
7acb45b157 | ||
|
|
c36142deaf | ||
|
|
5fd6e316fa | ||
|
|
39a9d7765a | ||
|
|
7cfcba29a6 | ||
|
|
9bf8aadca9 | ||
|
|
714d4af63d | ||
|
|
8203fdb4f0 | ||
|
|
5e1e2d1a4f | ||
|
|
2f941de65b | ||
|
|
777c503002 | ||
|
|
e9b23f68fd | ||
|
|
efa45e6203 | ||
|
|
638f55f83c | ||
|
|
8b2fc29d5b | ||
|
|
b516fb0550 | ||
|
|
efef34c01e | ||
|
|
5f1dfa7599 | ||
|
|
8e9c7544cf | ||
|
|
4e3d5641c8 | ||
|
|
20b760529e | ||
|
|
a55a07c5ff | ||
|
|
94ee8ea297 | ||
|
|
ec5d71d0e1 | ||
|
|
d121d08d05 | ||
|
|
be08f4a558 | ||
|
|
010f082fbb | ||
|
|
073cdf6d51 | ||
|
|
4df8606ab6 | ||
|
|
71442d26ec | ||
|
|
4f5528869c | ||
|
|
f16feff17b | ||
|
|
71b233fe5f | ||
|
|
770dec9ed6 | ||
|
|
2ca95a988e | ||
|
|
d8aae538cd | ||
|
|
cf1e7ee08a | ||
|
|
d14513ddfd | ||
|
|
9a9017bc6c | ||
|
|
3c9b654713 | ||
|
|
80d2ad40bc | ||
|
|
31670e75e5 | ||
|
|
ed6011a2be | ||
|
|
cdded38ade | ||
|
|
f536f24833 | ||
|
|
f5bff00b1f | ||
|
|
27c9717445 | ||
|
|
863a1ba8ef | ||
|
|
cb04dd2b83 | ||
|
|
8c7cf51958 | ||
|
|
244fb1fed6 | ||
|
|
25f7a68a13 | ||
|
|
62d8cf79ef | ||
|
|
646b18d910 | ||
|
|
2f81b2e381 | ||
|
|
1f5a7e7885 | ||
|
|
80fca470f2 | ||
|
|
6e9d9ac856 | ||
|
|
8d6fada1eb | ||
|
|
3e715399a1 | ||
|
|
81cc8831f9 | ||
|
|
f7370044a7 | ||
|
|
51b015a629 | ||
|
|
392af7a553 | ||
|
|
d2dd07bad7 | ||
|
|
cebcd6925a | ||
|
|
e7b4357fc7 | ||
|
|
dc279dde4a | ||
|
|
c0810a674f | ||
|
|
0760cabbbe | ||
|
|
3b149c520b | ||
|
|
3d19fc89ff | ||
|
|
cd1b1919f4 | ||
|
|
0ed646eb27 | ||
|
|
c0c5859c99 | ||
|
|
a47121b849 | ||
|
|
d9dd20e89a | ||
|
|
ed4609ebe5 | ||
|
|
e24225c828 | ||
|
|
01ef86d658 | ||
|
|
cd4802da04 | ||
|
|
2aca65780f | ||
|
|
2c435f7387 | ||
|
|
cc1afd1a9c | ||
|
|
6f098cdba6 | ||
|
|
d03e9fb90a | ||
|
|
9f2966abe9 | ||
|
|
4e28ea1883 | ||
|
|
289214e85c | ||
|
|
a20d98bf93 | ||
|
|
50a296de20 | ||
|
|
c79e38e044 | ||
|
|
ccb95f803c | ||
|
|
dae745d925 | ||
|
|
791db65526 | ||
|
|
02e2e617f5 | ||
|
|
bfc8024119 | ||
|
|
f26cf6ed6f | ||
|
|
f2be55bd8e | ||
|
|
d241dd17ca | ||
|
|
cecafdfe6c | ||
|
|
6fecfd1a0e | ||
|
|
1ce95c473d | ||
|
|
eb365e398d |
80
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
80
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -1,40 +1,56 @@
|
|||||||
name: '🥳 发布插件'
|
name: 🥳 发布插件
|
||||||
title: "[Plugin] 插件名"
|
|
||||||
description: 提交插件到插件市场
|
description: 提交插件到插件市场
|
||||||
labels: [ "plugin-publish" ]
|
title: "[Plugin] 插件名"
|
||||||
|
labels: ["plugin-publish"]
|
||||||
|
assignees: []
|
||||||
body:
|
body:
|
||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: |
|
value: |
|
||||||
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
|
欢迎发布插件到插件市场!
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: 插件仓库
|
|
||||||
description: 插件的 GitHub 仓库链接
|
|
||||||
placeholder: >
|
|
||||||
如 https://github.com/Soulter/astrbot-github-cards
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: 描述
|
|
||||||
value: |
|
|
||||||
插件名:
|
|
||||||
插件作者:
|
|
||||||
插件简介:
|
|
||||||
支持的消息平台:(必填,如 QQ、微信、飞书)
|
|
||||||
标签:(可选)
|
|
||||||
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
|
|
||||||
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
|
|
||||||
|
|
||||||
- type: checkboxes
|
|
||||||
attributes:
|
|
||||||
label: Code of Conduct
|
|
||||||
options:
|
|
||||||
- label: >
|
|
||||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: "❤️"
|
value: |
|
||||||
|
## 插件基本信息
|
||||||
|
|
||||||
|
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||||
|
|
||||||
|
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: plugin-info
|
||||||
|
attributes:
|
||||||
|
label: 插件信息
|
||||||
|
description: 请在下方代码块中填写您的插件信息,确保反引号包裹了JSON
|
||||||
|
value: |
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "插件名",
|
||||||
|
"desc": "插件介绍",
|
||||||
|
"author": "作者名",
|
||||||
|
"repo": "插件仓库链接",
|
||||||
|
"tags": [],
|
||||||
|
"social_link": ""
|
||||||
|
}
|
||||||
|
```
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
## 检查
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: checks
|
||||||
|
attributes:
|
||||||
|
label: 插件检查清单
|
||||||
|
description: 请确认以下所有项目
|
||||||
|
options:
|
||||||
|
- label: 我的插件经过完整的测试
|
||||||
|
required: true
|
||||||
|
- label: 我的插件不包含恶意代码
|
||||||
|
required: true
|
||||||
|
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||||
|
required: true
|
||||||
|
|||||||
63
.github/copilot-instructions.md
vendored
Normal file
63
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# AstrBot Development Instructions
|
||||||
|
|
||||||
|
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
|
||||||
|
|
||||||
|
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
|
||||||
|
|
||||||
|
## Working Effectively
|
||||||
|
|
||||||
|
### Bootstrap and Install Dependencies
|
||||||
|
- **Python 3.10+ required** - Check `.python-version` file
|
||||||
|
- Install UV package manager: `pip install uv`
|
||||||
|
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
|
||||||
|
- Create required directories: `mkdir -p data/plugins data/config data/temp`
|
||||||
|
|
||||||
|
### Running the Application
|
||||||
|
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||||
|
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||||
|
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||||
|
|
||||||
|
### Dashboard Build (Vue.js/Node.js)
|
||||||
|
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||||
|
- Navigate to dashboard: `cd dashboard`
|
||||||
|
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
|
||||||
|
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
|
||||||
|
- Dashboard creates optimized production build in `dashboard/dist/`
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
- Do not generate test files for now.
|
||||||
|
|
||||||
|
### Code Quality and Linting
|
||||||
|
- Install ruff linter: `uv add --dev ruff`
|
||||||
|
- Check code style: `uv run ruff check .` -- takes <1 second
|
||||||
|
- Check formatting: `uv run ruff format --check .` -- takes <1 second
|
||||||
|
- Fix formatting: `uv run ruff format .`
|
||||||
|
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||||
|
|
||||||
|
### Plugin Development
|
||||||
|
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||||
|
- Plugin system supports function tools and message handlers
|
||||||
|
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||||
|
|
||||||
|
### Common Issues and Workarounds
|
||||||
|
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
|
||||||
|
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
|
||||||
|
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
|
||||||
|
|
||||||
|
## CI/CD Integration
|
||||||
|
- GitHub Actions workflows in `.github/workflows/`
|
||||||
|
- Docker builds supported via `Dockerfile`
|
||||||
|
- Pre-commit hooks enforce ruff formatting and linting
|
||||||
|
|
||||||
|
## Docker Support
|
||||||
|
- Primary deployment method: `docker run soulter/astrbot:latest`
|
||||||
|
- Compose file available: `compose.yml`
|
||||||
|
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
|
||||||
|
- Volume mount required: `./data:/AstrBot/data`
|
||||||
|
|
||||||
|
## Multi-language Support
|
||||||
|
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
|
||||||
|
- UI supports internationalization
|
||||||
|
- Default language is Chinese
|
||||||
|
|
||||||
|
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
|
||||||
13
.github/dependabot.yml
vendored
Normal file
13
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Keep GitHub Actions up to date with GitHub's Dependabot...
|
||||||
|
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
|
||||||
|
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
|
||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: github-actions
|
||||||
|
directory: /
|
||||||
|
groups:
|
||||||
|
github-actions:
|
||||||
|
patterns:
|
||||||
|
- "*" # Group all Actions updates into a single larger pull request
|
||||||
|
schedule:
|
||||||
|
interval: weekly
|
||||||
6
.github/workflows/auto_release.yml
vendored
6
.github/workflows/auto_release.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
|||||||
contents: write
|
contents: write
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Dashboard Build
|
- name: Dashboard Build
|
||||||
run: |
|
run: |
|
||||||
@@ -70,10 +70,10 @@ jobs:
|
|||||||
needs: build-and-publish-to-github-release
|
needs: build-and-publish-to-github-release
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
|||||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
|
|||||||
20
.github/workflows/coverage_test.yml
vendored
20
.github/workflows/coverage_test.yml
vendored
@@ -8,6 +8,7 @@ on:
|
|||||||
- 'README.md'
|
- 'README.md'
|
||||||
- 'changelogs/**'
|
- 'changelogs/**'
|
||||||
- 'dashboard/**'
|
- 'dashboard/**'
|
||||||
|
pull_request:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -16,30 +17,29 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install -r requirements.txt
|
pip install pytest pytest-asyncio pytest-cov
|
||||||
pip install pytest pytest-cov pytest-asyncio
|
pip install --editable .
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
mkdir data
|
mkdir -p data/plugins
|
||||||
mkdir data/plugins
|
mkdir -p data/config
|
||||||
mkdir data/config
|
mkdir -p data/temp
|
||||||
mkdir data/temp
|
|
||||||
export TESTING=true
|
export TESTING=true
|
||||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
|
||||||
|
|
||||||
- name: Upload results to Codecov
|
- name: Upload results to Codecov
|
||||||
uses: codecov/codecov-action@v4
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
2
.github/workflows/dashboard_ci.yml
vendored
2
.github/workflows/dashboard_ci.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: npm install, build
|
- name: npm install, build
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
2
.github/workflows/docker-image.yml
vendored
2
.github/workflows/docker-image.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Pull The Codes
|
- name: Pull The Codes
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v5
|
- uses: actions/stale@v9
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: 'Stale issue message'
|
stale-issue-message: 'Stale issue message'
|
||||||
|
|||||||
95
README.md
95
README.md
@@ -16,7 +16,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||

|

|
||||||

|

|
||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
@@ -27,57 +27,50 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
|
|
||||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||||
|
|
||||||
|
|
||||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
|
||||||
-->
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
>
|
|
||||||
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
|
|
||||||
|
|
||||||
## ✨ 近期更新
|
|
||||||
|
|
||||||
<details><summary>1. AstrBot 现已自带知识库能力</summary>
|
|
||||||
|
|
||||||
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
|
||||||
|
|
||||||
## ✨ 主要功能
|
## ✨ 主要功能
|
||||||
|
|
||||||
> [!NOTE]
|
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||||
|
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||||
2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
|
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
|
||||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
|
||||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
|
||||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
|
||||||
>
|
|
||||||
> 用户名: `astrbot`, 密码: `astrbot`。
|
|
||||||
|
|
||||||
## ✨ 使用方式
|
## ✨ 使用方式
|
||||||
|
|
||||||
#### Docker 部署
|
#### Docker 部署
|
||||||
|
|
||||||
|
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||||
|
|
||||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||||
|
|
||||||
|
#### 宝塔面板部署
|
||||||
|
|
||||||
|
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||||
|
|
||||||
|
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||||
|
|
||||||
|
#### 1Panel 部署
|
||||||
|
|
||||||
|
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||||
|
|
||||||
|
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
|
||||||
|
|
||||||
|
#### 在 雨云 上部署
|
||||||
|
|
||||||
|
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
||||||
|
|
||||||
|
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||||
|
|
||||||
|
#### 在 Replit 上部署
|
||||||
|
|
||||||
|
社区贡献的部署方式。
|
||||||
|
|
||||||
|
[](https://repl.it/github/Soulter/AstrBot)
|
||||||
|
|
||||||
#### Windows 一键安装器部署
|
#### Windows 一键安装器部署
|
||||||
|
|
||||||
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||||
|
|
||||||
#### 宝塔面板部署
|
|
||||||
|
|
||||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
|
||||||
|
|
||||||
#### CasaOS 部署
|
#### CasaOS 部署
|
||||||
|
|
||||||
社区贡献的部署方式。
|
社区贡献的部署方式。
|
||||||
@@ -101,27 +94,14 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
|||||||
uv run main.py
|
uv run main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
或者,直接通过 uvx 安装 AstrBot:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir astrbot && cd astrbot
|
|
||||||
uvx astrbot init
|
|
||||||
# uvx astrbot run
|
|
||||||
```
|
|
||||||
|
|
||||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||||
|
|
||||||
#### Replit 部署
|
|
||||||
|
|
||||||
[](https://repl.it/github/Soulter/AstrBot)
|
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
| 平台 | 支持性 |
|
| 平台 | 支持性 |
|
||||||
| -------- | ------- |
|
| -------- | ------- |
|
||||||
| QQ(官方机器人接口) | ✔ |
|
| QQ(官方机器人接口) | ✔ |
|
||||||
| QQ(OneBot) | ✔ |
|
| QQ(OneBot) | ✔ |
|
||||||
| 微信个人号 | ✔ |
|
|
||||||
| Telegram | ✔ |
|
| Telegram | ✔ |
|
||||||
| 企业微信 | ✔ |
|
| 企业微信 | ✔ |
|
||||||
| 微信客服 | ✔ |
|
| 微信客服 | ✔ |
|
||||||
@@ -140,7 +120,7 @@ uvx astrbot init
|
|||||||
|
|
||||||
| 名称 | 支持性 | 类型 | 备注 |
|
| 名称 | 支持性 | 类型 | 备注 |
|
||||||
| -------- | ------- | ------- | ------- |
|
| -------- | ------- | ------- | ------- |
|
||||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||||
| Claude API | ✔ | 文本生成 | |
|
| Claude API | ✔ | 文本生成 | |
|
||||||
| Google Gemini API | ✔ | 文本生成 | |
|
| Google Gemini API | ✔ | 文本生成 | |
|
||||||
| Dify | ✔ | LLMOps | |
|
| Dify | ✔ | LLMOps | |
|
||||||
@@ -148,6 +128,8 @@ uvx astrbot init
|
|||||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||||
|
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||||
|
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||||
| OneAPI | ✔ | LLM 分发系统 | |
|
| OneAPI | ✔ | LLM 分发系统 | |
|
||||||
@@ -223,7 +205,7 @@ _✨ WebUI ✨_
|
|||||||
|
|
||||||
此外,本项目的诞生离不开以下开源项目:
|
此外,本项目的诞生离不开以下开源项目:
|
||||||
|
|
||||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
|
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
@@ -237,11 +219,8 @@ _✨ WebUI ✨_
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
## Disclaimer
|

|
||||||
|
|
||||||
1. The project is protected under the `AGPL-v3` opensource license.
|
|
||||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
|
||||||
3. Please ensure compliance with local laws and regulations when using this project.
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
|||||||
## ✨ 主な機能
|
## ✨ 主な機能
|
||||||
|
|
||||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||||
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
|||||||
## 免責事項
|
## 免責事項
|
||||||
|
|
||||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||||
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||||
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
|
||||||
|
|
||||||
<!-- ## ✨ ATRI [ベータテスト]
|
<!-- ## ✨ ATRI [ベータテスト]
|
||||||
|
|
||||||
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
|||||||
4. TTS
|
4. TTS
|
||||||
-->
|
-->
|
||||||
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
0
astrbot.lock
Normal file
0
astrbot.lock
Normal file
@@ -1 +1 @@
|
|||||||
__version__ = "3.5.8"
|
__version__ = "3.5.23"
|
||||||
|
|||||||
@@ -139,6 +139,14 @@ def conf():
|
|||||||
- dashboard.password: Dashboard 密码
|
- dashboard.password: Dashboard 密码
|
||||||
|
|
||||||
- callback_api_base: 回调接口基址
|
- callback_api_base: 回调接口基址
|
||||||
|
|
||||||
|
可用子命令:
|
||||||
|
|
||||||
|
- set: 设置配置项值
|
||||||
|
|
||||||
|
- get: 获取配置项值
|
||||||
|
|
||||||
|
- login-info: 显示 Web 管理面板登录信息
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -204,3 +212,44 @@ def get_config(key: str = None):
|
|||||||
click.echo(f" {key}: {value}")
|
click.echo(f" {key}: {value}")
|
||||||
except (KeyError, TypeError):
|
except (KeyError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@conf.command(name="login-info")
|
||||||
|
def get_login_info():
|
||||||
|
"""显示 Web 管理面板的登录信息
|
||||||
|
|
||||||
|
在 Docker 环境中使用示例:
|
||||||
|
docker exec -e ASTRBOT_ROOT=/AstrBot astrbot-container astrbot conf login-info
|
||||||
|
"""
|
||||||
|
config = _load_config()
|
||||||
|
|
||||||
|
try:
|
||||||
|
username = _get_nested_item(config, "dashboard.username")
|
||||||
|
# 注意:我们不显示实际的MD5哈希密码,而是提示用户如何重置
|
||||||
|
click.echo("🔐 Web 管理面板登录信息:")
|
||||||
|
click.echo(f" 用户名: {username}")
|
||||||
|
click.echo(" 密码: [已加密存储]")
|
||||||
|
click.echo()
|
||||||
|
click.echo("💡 如需重置密码,请使用以下命令:")
|
||||||
|
click.echo(" astrbot conf set dashboard.password <新密码>")
|
||||||
|
click.echo()
|
||||||
|
click.echo("🌐 访问地址:")
|
||||||
|
|
||||||
|
# 尝试获取端口信息
|
||||||
|
try:
|
||||||
|
port = _get_nested_item(config, "dashboard.port")
|
||||||
|
click.echo(f" http://localhost:{port}")
|
||||||
|
click.echo(f" http://your-server-ip:{port}")
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
click.echo(" http://localhost:6185 (默认端口)")
|
||||||
|
click.echo(" http://your-server-ip:6185 (默认端口)")
|
||||||
|
|
||||||
|
click.echo()
|
||||||
|
click.echo("📋 Docker 环境使用说明:")
|
||||||
|
click.echo(" 如果在 Docker 中运行,请使用以下命令格式:")
|
||||||
|
click.echo(" docker exec -e ASTRBOT_ROOT=/AstrBot <容器名> astrbot conf login-info")
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
click.echo("❌ 无法找到登录配置,请先运行 'astrbot init' 初始化")
|
||||||
|
except Exception as e:
|
||||||
|
raise click.UsageError(f"获取登录信息失败: {str(e)}")
|
||||||
|
|||||||
@@ -16,7 +16,13 @@ def check_astrbot_root(path: str | Path) -> bool:
|
|||||||
|
|
||||||
def get_astrbot_root() -> Path:
|
def get_astrbot_root() -> Path:
|
||||||
"""获取Astrbot根目录路径"""
|
"""获取Astrbot根目录路径"""
|
||||||
return Path.cwd()
|
import os
|
||||||
|
|
||||||
|
# 使用与core应用相同的路径解析逻辑,优先使用ASTRBOT_ROOT环境变量
|
||||||
|
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||||
|
return Path(path)
|
||||||
|
else:
|
||||||
|
return Path.cwd()
|
||||||
|
|
||||||
|
|
||||||
async def check_dashboard(astrbot_root: Path) -> None:
|
async def check_dashboard(astrbot_root: Path) -> None:
|
||||||
|
|||||||
@@ -117,19 +117,24 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|||||||
# 从 metadata.yaml 加载元数据
|
# 从 metadata.yaml 加载元数据
|
||||||
metadata = load_yaml_metadata(plugin_dir)
|
metadata = load_yaml_metadata(plugin_dir)
|
||||||
|
|
||||||
|
if "desc" not in metadata and "description" in metadata:
|
||||||
|
metadata["desc"] = metadata["description"]
|
||||||
|
|
||||||
# 如果成功加载元数据,添加到结果列表
|
# 如果成功加载元数据,添加到结果列表
|
||||||
if metadata and all(
|
if metadata and all(
|
||||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||||
):
|
):
|
||||||
result.append({
|
result.append(
|
||||||
"name": str(metadata.get("name", "")),
|
{
|
||||||
"desc": str(metadata.get("desc", "")),
|
"name": str(metadata.get("name", "")),
|
||||||
"version": str(metadata.get("version", "")),
|
"desc": str(metadata.get("desc", "")),
|
||||||
"author": str(metadata.get("author", "")),
|
"version": str(metadata.get("version", "")),
|
||||||
"repo": str(metadata.get("repo", "")),
|
"author": str(metadata.get("author", "")),
|
||||||
"status": PluginStatus.INSTALLED,
|
"repo": str(metadata.get("repo", "")),
|
||||||
"local_path": str(plugin_dir),
|
"status": PluginStatus.INSTALLED,
|
||||||
})
|
"local_path": str(plugin_dir),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 获取在线插件列表
|
# 获取在线插件列表
|
||||||
online_plugins = []
|
online_plugins = []
|
||||||
@@ -139,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
for plugin_id, plugin_info in data.items():
|
for plugin_id, plugin_info in data.items():
|
||||||
online_plugins.append({
|
online_plugins.append(
|
||||||
"name": str(plugin_id),
|
{
|
||||||
"desc": str(plugin_info.get("desc", "")),
|
"name": str(plugin_id),
|
||||||
"version": str(plugin_info.get("version", "")),
|
"desc": str(plugin_info.get("desc", "")),
|
||||||
"author": str(plugin_info.get("author", "")),
|
"version": str(plugin_info.get("version", "")),
|
||||||
"repo": str(plugin_info.get("repo", "")),
|
"author": str(plugin_info.get("author", "")),
|
||||||
"status": PluginStatus.NOT_INSTALLED,
|
"repo": str(plugin_info.get("repo", "")),
|
||||||
"local_path": None,
|
"status": PluginStatus.NOT_INSTALLED,
|
||||||
})
|
"local_path": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from .utils.astrbot_path import get_astrbot_data_path
|
|||||||
# 初始化数据存储文件夹
|
# 初始化数据存储文件夹
|
||||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||||
|
|
||||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
|
||||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||||
|
|
||||||
astrbot_config = AstrBotConfig()
|
astrbot_config = AstrBotConfig()
|
||||||
@@ -29,6 +28,3 @@ pip_installer = PipInstaller(
|
|||||||
astrbot_config.get("pip_install_arg", ""),
|
astrbot_config.get("pip_install_arg", ""),
|
||||||
astrbot_config.get("pypi_index_url", None),
|
astrbot_config.get("pypi_index_url", None),
|
||||||
)
|
)
|
||||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
|
||||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,17 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "3.5.17"
|
VERSION = "3.5.24"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"config_version": 2,
|
"config_version": 2,
|
||||||
"platform_settings": {
|
"platform_settings": {
|
||||||
|
"plugin_enable": {},
|
||||||
"unique_session": False,
|
"unique_session": False,
|
||||||
"rate_limit": {
|
"rate_limit": {
|
||||||
"time": 60,
|
"time": 60,
|
||||||
@@ -52,6 +54,7 @@ DEFAULT_CONFIG = {
|
|||||||
"wake_prefix": "",
|
"wake_prefix": "",
|
||||||
"web_search": False,
|
"web_search": False,
|
||||||
"web_search_link": False,
|
"web_search_link": False,
|
||||||
|
"display_reasoning_text": False,
|
||||||
"identifier": False,
|
"identifier": False,
|
||||||
"datetime_system_prompt": True,
|
"datetime_system_prompt": True,
|
||||||
"default_personality": "default",
|
"default_personality": "default",
|
||||||
@@ -59,8 +62,10 @@ DEFAULT_CONFIG = {
|
|||||||
"max_context_length": -1,
|
"max_context_length": -1,
|
||||||
"dequeue_context_length": 1,
|
"dequeue_context_length": 1,
|
||||||
"streaming_response": False,
|
"streaming_response": False,
|
||||||
|
"show_tool_use_status": False,
|
||||||
"streaming_segmented": False,
|
"streaming_segmented": False,
|
||||||
"separate_provider": False,
|
"separate_provider": True,
|
||||||
|
"max_agent_step": 30,
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -102,6 +107,7 @@ DEFAULT_CONFIG = {
|
|||||||
"enable": True,
|
"enable": True,
|
||||||
"username": "astrbot",
|
"username": "astrbot",
|
||||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||||
|
"jwt_secret": "",
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": 6185,
|
"port": 6185,
|
||||||
},
|
},
|
||||||
@@ -152,15 +158,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"ws_reverse_port": 6199,
|
"ws_reverse_port": 6199,
|
||||||
"ws_reverse_token": "",
|
"ws_reverse_token": "",
|
||||||
},
|
},
|
||||||
"微信个人号(Gewechat)": {
|
|
||||||
"id": "gwchat",
|
|
||||||
"type": "gewechat",
|
|
||||||
"enable": False,
|
|
||||||
"base_url": "http://localhost:2531",
|
|
||||||
"nickname": "soulter",
|
|
||||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
|
||||||
"port": 11451,
|
|
||||||
},
|
|
||||||
"微信个人号(WeChatPadPro)": {
|
"微信个人号(WeChatPadPro)": {
|
||||||
"id": "wechatpadpro",
|
"id": "wechatpadpro",
|
||||||
"type": "wechatpadpro",
|
"type": "wechatpadpro",
|
||||||
@@ -313,8 +310,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"id": {
|
"id": {
|
||||||
"description": "机器人名称",
|
"description": "机器人名称",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
"hint": "机器人名称",
|
||||||
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
|
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"description": "适配器类型",
|
"description": "适配器类型",
|
||||||
@@ -365,17 +361,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "飞书机器人的名字",
|
"description": "飞书机器人的名字",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"discord_token":{
|
"discord_token": {
|
||||||
"description": "Discord Bot Token",
|
"description": "Discord Bot Token",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "在此处填入你的Discord Bot Token"
|
"hint": "在此处填入你的Discord Bot Token",
|
||||||
},
|
},
|
||||||
"discord_proxy":{
|
"discord_proxy": {
|
||||||
"description": "Discord 代理地址",
|
"description": "Discord 代理地址",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "可选的代理地址:http://ip:port"
|
"hint": "可选的代理地址:http://ip:port",
|
||||||
},
|
},
|
||||||
"discord_command_register": {
|
"discord_command_register": {
|
||||||
"description": "是否自动将插件指令注册为 Discord 斜杠指令",
|
"description": "是否自动将插件指令注册为 Discord 斜杠指令",
|
||||||
@@ -386,10 +381,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
|
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
|
||||||
},
|
},
|
||||||
"discord_guild_id_for_debug": {
|
|
||||||
"description": "【开发用】指定一个服务器(Guild)ID。在此服务器注册的指令会立刻生效,便于调试。留空则注册为全局指令。",
|
|
||||||
"type": "string",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"platform_settings": {
|
"platform_settings": {
|
||||||
@@ -442,7 +433,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"ignore_bot_self_message": {
|
"ignore_bot_self_message": {
|
||||||
"description": "是否忽略机器人自身的消息",
|
"description": "是否忽略机器人自身的消息",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
"hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||||
},
|
},
|
||||||
"ignore_at_all": {
|
"ignore_at_all": {
|
||||||
"description": "是否忽略 @ 全体成员",
|
"description": "是否忽略 @ 全体成员",
|
||||||
@@ -485,13 +476,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"regex": {
|
"regex": {
|
||||||
"description": "正则表达式",
|
"description": "正则表达式",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||||
},
|
},
|
||||||
"content_cleanup_rule": {
|
"content_cleanup_rule": {
|
||||||
"description": "过滤分段后的内容",
|
"description": "过滤分段后的内容",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
|
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -514,7 +503,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "ID 白名单",
|
"description": "ID 白名单",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
|
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
|
||||||
},
|
},
|
||||||
"id_whitelist_log": {
|
"id_whitelist_log": {
|
||||||
@@ -544,7 +532,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "路径映射",
|
"description": "路径映射",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -604,18 +591,19 @@ CONFIG_METADATA_2 = {
|
|||||||
"config_template": {
|
"config_template": {
|
||||||
"OpenAI": {
|
"OpenAI": {
|
||||||
"id": "openai",
|
"id": "openai",
|
||||||
|
"provider": "openai",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://api.openai.com/v1",
|
"api_base": "https://api.openai.com/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||||
"model": "gpt-4o-mini",
|
"hint": "也兼容所有与OpenAI API兼容的服务。",
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"Azure OpenAI": {
|
"Azure OpenAI": {
|
||||||
"id": "azure",
|
"id": "azure",
|
||||||
|
"provider": "azure",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -623,24 +611,23 @@ CONFIG_METADATA_2 = {
|
|||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "",
|
"api_base": "",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||||
"model": "gpt-4o-mini",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"xAI": {
|
"xAI": {
|
||||||
"id": "xai",
|
"id": "xai",
|
||||||
|
"provider": "xai",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://api.x.ai/v1",
|
"api_base": "https://api.x.ai/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||||
"model": "grok-2-latest",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"Anthropic": {
|
"Anthropic": {
|
||||||
|
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
|
||||||
"id": "claude",
|
"id": "claude",
|
||||||
|
"provider": "anthropic",
|
||||||
"type": "anthropic_chat_completion",
|
"type": "anthropic_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -650,21 +637,23 @@ CONFIG_METADATA_2 = {
|
|||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "claude-3-5-sonnet-latest",
|
"model": "claude-3-5-sonnet-latest",
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
"temperature": 0.2,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"Ollama": {
|
"Ollama": {
|
||||||
|
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
|
||||||
"id": "ollama_default",
|
"id": "ollama_default",
|
||||||
|
"provider": "ollama",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||||
"api_base": "http://localhost:11434/v1",
|
"api_base": "http://localhost:11434/v1",
|
||||||
"model_config": {
|
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||||
"model": "llama3.1-8b",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"LM Studio": {
|
"LM Studio": {
|
||||||
"id": "lm_studio",
|
"id": "lm_studio",
|
||||||
|
"provider": "lm_studio",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -676,6 +665,7 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
"Gemini(OpenAI兼容)": {
|
"Gemini(OpenAI兼容)": {
|
||||||
"id": "gemini_default",
|
"id": "gemini_default",
|
||||||
|
"provider": "google",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -684,10 +674,12 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-1.5-flash",
|
||||||
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"Gemini": {
|
"Gemini": {
|
||||||
"id": "gemini_default",
|
"id": "gemini_default",
|
||||||
|
"provider": "google",
|
||||||
"type": "googlegenai_chat_completion",
|
"type": "googlegenai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -696,6 +688,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-2.0-flash-exp",
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"gm_resp_image_modal": False,
|
"gm_resp_image_modal": False,
|
||||||
"gm_native_search": False,
|
"gm_native_search": False,
|
||||||
@@ -713,18 +706,81 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
"DeepSeek": {
|
"DeepSeek": {
|
||||||
"id": "deepseek_default",
|
"id": "deepseek_default",
|
||||||
|
"provider": "deepseek",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://api.deepseek.com/v1",
|
"api_base": "https://api.deepseek.com/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
|
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||||
|
},
|
||||||
|
"302.AI": {
|
||||||
|
"id": "302ai",
|
||||||
|
"provider": "302ai",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"api_base": "https://api.302.ai/v1",
|
||||||
|
"timeout": 120,
|
||||||
|
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||||
|
},
|
||||||
|
"硅基流动": {
|
||||||
|
"id": "siliconflow",
|
||||||
|
"provider": "siliconflow",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"timeout": 120,
|
||||||
|
"api_base": "https://api.siliconflow.cn/v1",
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "deepseek-chat",
|
"model": "deepseek-ai/DeepSeek-V3",
|
||||||
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"PPIO派欧云": {
|
||||||
|
"id": "ppio",
|
||||||
|
"provider": "ppio",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||||
|
"timeout": 120,
|
||||||
|
"model_config": {
|
||||||
|
"model": "deepseek/deepseek-r1",
|
||||||
|
"temperature": 0.4,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"优云智算": {
|
||||||
|
"id": "compshare",
|
||||||
|
"provider": "compshare",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"api_base": "https://api.modelverse.cn/v1",
|
||||||
|
"timeout": 120,
|
||||||
|
"model_config": {
|
||||||
|
"model": "moonshotai/Kimi-K2-Instruct",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Kimi": {
|
||||||
|
"id": "moonshot",
|
||||||
|
"provider": "moonshot",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"timeout": 120,
|
||||||
|
"api_base": "https://api.moonshot.cn/v1",
|
||||||
|
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||||
|
},
|
||||||
"智谱 AI": {
|
"智谱 AI": {
|
||||||
"id": "zhipu_default",
|
"id": "zhipu_default",
|
||||||
|
"provider": "zhipu",
|
||||||
"type": "zhipu_chat_completion",
|
"type": "zhipu_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -735,55 +791,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "glm-4-flash",
|
"model": "glm-4-flash",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"硅基流动": {
|
|
||||||
"id": "siliconflow",
|
|
||||||
"type": "openai_chat_completion",
|
|
||||||
"provider_type": "chat_completion",
|
|
||||||
"enable": True,
|
|
||||||
"key": [],
|
|
||||||
"timeout": 120,
|
|
||||||
"api_base": "https://api.siliconflow.cn/v1",
|
|
||||||
"model_config": {
|
|
||||||
"model": "deepseek-ai/DeepSeek-V3",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"Kimi": {
|
|
||||||
"id": "moonshot",
|
|
||||||
"type": "openai_chat_completion",
|
|
||||||
"provider_type": "chat_completion",
|
|
||||||
"enable": True,
|
|
||||||
"key": [],
|
|
||||||
"timeout": 120,
|
|
||||||
"api_base": "https://api.moonshot.cn/v1",
|
|
||||||
"model_config": {
|
|
||||||
"model": "moonshot-v1-8k",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"PPIO派欧云": {
|
|
||||||
"id": "ppio",
|
|
||||||
"type": "openai_chat_completion",
|
|
||||||
"provider_type": "chat_completion",
|
|
||||||
"enable": True,
|
|
||||||
"key": [],
|
|
||||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
|
||||||
"timeout": 120,
|
|
||||||
"model_config": {
|
|
||||||
"model": "deepseek/deepseek-r1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"LLMTuner": {
|
|
||||||
"id": "llmtuner_default",
|
|
||||||
"type": "llm_tuner",
|
|
||||||
"provider_type": "chat_completion",
|
|
||||||
"enable": True,
|
|
||||||
"base_model_path": "",
|
|
||||||
"adapter_model_path": "",
|
|
||||||
"llmtuner_template": "",
|
|
||||||
"finetuning_type": "lora",
|
|
||||||
"quantization_bit": 4,
|
|
||||||
},
|
|
||||||
"Dify": {
|
"Dify": {
|
||||||
"id": "dify_app_default",
|
"id": "dify_app_default",
|
||||||
|
"provider": "dify",
|
||||||
"type": "dify",
|
"type": "dify",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -794,9 +804,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"dify_query_input_key": "astrbot_text_query",
|
"dify_query_input_key": "astrbot_text_query",
|
||||||
"variables": {},
|
"variables": {},
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
|
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||||
},
|
},
|
||||||
"阿里云百炼应用": {
|
"阿里云百炼应用": {
|
||||||
"id": "dashscope",
|
"id": "dashscope",
|
||||||
|
"provider": "dashscope",
|
||||||
"type": "dashscope",
|
"type": "dashscope",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -811,8 +823,20 @@ CONFIG_METADATA_2 = {
|
|||||||
"variables": {},
|
"variables": {},
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
},
|
},
|
||||||
|
"ModelScope": {
|
||||||
|
"id": "modelscope",
|
||||||
|
"provider": "modelscope",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"provider_type": "chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"timeout": 120,
|
||||||
|
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||||
|
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||||
|
},
|
||||||
"FastGPT": {
|
"FastGPT": {
|
||||||
"id": "fastgpt",
|
"id": "fastgpt",
|
||||||
|
"provider": "fastgpt",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"provider_type": "chat_completion",
|
"provider_type": "chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -822,6 +846,7 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
"Whisper(API)": {
|
"Whisper(API)": {
|
||||||
"id": "whisper",
|
"id": "whisper",
|
||||||
|
"provider": "openai",
|
||||||
"type": "openai_whisper_api",
|
"type": "openai_whisper_api",
|
||||||
"provider_type": "speech_to_text",
|
"provider_type": "speech_to_text",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -830,16 +855,18 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "whisper-1",
|
"model": "whisper-1",
|
||||||
},
|
},
|
||||||
"Whisper(本地加载)": {
|
"Whisper(本地加载)": {
|
||||||
"whisper_hint": "(不用修改我)",
|
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||||
|
"provider": "openai",
|
||||||
"type": "openai_whisper_selfhost",
|
"type": "openai_whisper_selfhost",
|
||||||
"provider_type": "speech_to_text",
|
"provider_type": "speech_to_text",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"id": "whisper",
|
"id": "whisper_selfhost",
|
||||||
"model": "tiny",
|
"model": "tiny",
|
||||||
},
|
},
|
||||||
"SenseVoice(本地加载)": {
|
"SenseVoice(本地加载)": {
|
||||||
"sensevoice_hint": "(不用修改我)",
|
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||||
"type": "sensevoice_stt_selfhost",
|
"type": "sensevoice_stt_selfhost",
|
||||||
|
"provider": "sensevoice",
|
||||||
"provider_type": "speech_to_text",
|
"provider_type": "speech_to_text",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"id": "sensevoice",
|
"id": "sensevoice",
|
||||||
@@ -849,6 +876,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"OpenAI TTS(API)": {
|
"OpenAI TTS(API)": {
|
||||||
"id": "openai_tts",
|
"id": "openai_tts",
|
||||||
"type": "openai_tts_api",
|
"type": "openai_tts_api",
|
||||||
|
"provider": "openai",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
@@ -858,8 +886,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": "20",
|
"timeout": "20",
|
||||||
},
|
},
|
||||||
"Edge TTS": {
|
"Edge TTS": {
|
||||||
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||||
"id": "edge_tts",
|
"id": "edge_tts",
|
||||||
|
"provider": "microsoft",
|
||||||
"type": "edge_tts",
|
"type": "edge_tts",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -869,6 +898,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"GSV TTS(本地加载)": {
|
"GSV TTS(本地加载)": {
|
||||||
"id": "gsv_tts",
|
"id": "gsv_tts",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
|
"provider": "gpt_sovits",
|
||||||
"type": "gsv_tts_selfhost",
|
"type": "gsv_tts_selfhost",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"api_base": "http://127.0.0.1:9880",
|
"api_base": "http://127.0.0.1:9880",
|
||||||
@@ -900,6 +930,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"GSVI TTS(API)": {
|
"GSVI TTS(API)": {
|
||||||
"id": "gsvi_tts",
|
"id": "gsvi_tts",
|
||||||
"type": "gsvi_tts_api",
|
"type": "gsvi_tts_api",
|
||||||
|
"provider": "gpt_sovits_inference",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"api_base": "http://127.0.0.1:5000",
|
"api_base": "http://127.0.0.1:5000",
|
||||||
"character": "",
|
"character": "",
|
||||||
@@ -909,6 +940,7 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
"FishAudio TTS(API)": {
|
"FishAudio TTS(API)": {
|
||||||
"id": "fishaudio_tts",
|
"id": "fishaudio_tts",
|
||||||
|
"provider": "fishaudio",
|
||||||
"type": "fishaudio_tts_api",
|
"type": "fishaudio_tts_api",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -919,6 +951,7 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
"阿里云百炼 TTS(API)": {
|
"阿里云百炼 TTS(API)": {
|
||||||
"id": "dashscope_tts",
|
"id": "dashscope_tts",
|
||||||
|
"provider": "dashscope",
|
||||||
"type": "dashscope_tts",
|
"type": "dashscope_tts",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -930,6 +963,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"Azure TTS": {
|
"Azure TTS": {
|
||||||
"id": "azure_tts",
|
"id": "azure_tts",
|
||||||
"type": "azure_tts",
|
"type": "azure_tts",
|
||||||
|
"provider": "azure",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"azure_tts_voice": "zh-CN-YunxiaNeural",
|
"azure_tts_voice": "zh-CN-YunxiaNeural",
|
||||||
@@ -943,6 +977,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"MiniMax TTS(API)": {
|
"MiniMax TTS(API)": {
|
||||||
"id": "minimax_tts",
|
"id": "minimax_tts",
|
||||||
"type": "minimax_tts_api",
|
"type": "minimax_tts_api",
|
||||||
|
"provider": "minimax",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
@@ -964,6 +999,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"火山引擎_TTS(API)": {
|
"火山引擎_TTS(API)": {
|
||||||
"id": "volcengine_tts",
|
"id": "volcengine_tts",
|
||||||
"type": "volcengine_tts",
|
"type": "volcengine_tts",
|
||||||
|
"provider": "volcengine",
|
||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
@@ -974,20 +1010,35 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||||
"timeout": 20,
|
"timeout": 20,
|
||||||
},
|
},
|
||||||
|
"Gemini TTS": {
|
||||||
|
"id": "gemini_tts",
|
||||||
|
"type": "gemini_tts",
|
||||||
|
"provider": "google",
|
||||||
|
"provider_type": "text_to_speech",
|
||||||
|
"enable": False,
|
||||||
|
"gemini_tts_api_key": "",
|
||||||
|
"gemini_tts_api_base": "",
|
||||||
|
"gemini_tts_timeout": 20,
|
||||||
|
"gemini_tts_model": "gemini-2.5-flash-preview-tts",
|
||||||
|
"gemini_tts_prefix": "",
|
||||||
|
"gemini_tts_voice_name": "Leda",
|
||||||
|
},
|
||||||
"OpenAI Embedding": {
|
"OpenAI Embedding": {
|
||||||
"id": "openai_embedding",
|
"id": "openai_embedding",
|
||||||
"type": "openai_embedding",
|
"type": "openai_embedding",
|
||||||
|
"provider": "openai",
|
||||||
"provider_type": "embedding",
|
"provider_type": "embedding",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"embedding_api_key": "",
|
"embedding_api_key": "",
|
||||||
"embedding_api_base": "",
|
"embedding_api_base": "",
|
||||||
"embedding_model": "",
|
"embedding_model": "",
|
||||||
"embedding_dimensions": 1536,
|
"embedding_dimensions": 1024,
|
||||||
"timeout": 20,
|
"timeout": 20,
|
||||||
},
|
},
|
||||||
"Gemini Embedding": {
|
"Gemini Embedding": {
|
||||||
"id": "gemini_embedding",
|
"id": "gemini_embedding",
|
||||||
"type": "gemini_embedding",
|
"type": "gemini_embedding",
|
||||||
|
"provider": "google",
|
||||||
"provider_type": "embedding",
|
"provider_type": "embedding",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"embedding_api_key": "",
|
"embedding_api_key": "",
|
||||||
@@ -998,17 +1049,19 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
|
"provider": {
|
||||||
|
"type": "string",
|
||||||
|
"invisible": True,
|
||||||
|
},
|
||||||
"gpt_weights_path": {
|
"gpt_weights_path": {
|
||||||
"description": "GPT模型文件路径",
|
"description": "GPT模型文件路径",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"sovits_weights_path": {
|
"sovits_weights_path": {
|
||||||
"description": "SoVITS模型文件路径",
|
"description": "SoVITS模型文件路径",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"gsv_default_parms": {
|
"gsv_default_parms": {
|
||||||
"description": "GPT_SoVITS默认参数",
|
"description": "GPT_SoVITS默认参数",
|
||||||
@@ -1019,13 +1072,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "参考音频文件路径",
|
"description": "参考音频文件路径",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
|
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"gsv_prompt_text": {
|
"gsv_prompt_text": {
|
||||||
"description": "参考音频文本",
|
"description": "参考音频文本",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "必填!请填写参考音频讲述的文本",
|
"hint": "必填!请填写参考音频讲述的文本",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"gsv_prompt_lang": {
|
"gsv_prompt_lang": {
|
||||||
"description": "参考音频文本语言",
|
"description": "参考音频文本语言",
|
||||||
@@ -1252,19 +1303,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "启用原生搜索功能",
|
"description": "启用原生搜索功能",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"gm_native_coderunner": {
|
"gm_native_coderunner": {
|
||||||
"description": "启用原生代码执行器",
|
"description": "启用原生代码执行器",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后所有函数工具将全部失效",
|
"hint": "启用后所有函数工具将全部失效",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"gm_url_context": {
|
"gm_url_context": {
|
||||||
"description": "启用URL上下文功能",
|
"description": "启用URL上下文功能",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后所有函数工具将全部失效",
|
"hint": "启用后所有函数工具将全部失效",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"gm_safety_settings": {
|
"gm_safety_settings": {
|
||||||
"description": "安全过滤器",
|
"description": "安全过滤器",
|
||||||
@@ -1448,7 +1496,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "部署SenseVoice",
|
"description": "部署SenseVoice",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"is_emotion": {
|
"is_emotion": {
|
||||||
"description": "情绪识别",
|
"description": "情绪识别",
|
||||||
@@ -1463,18 +1510,10 @@ CONFIG_METADATA_2 = {
|
|||||||
"variables": {
|
"variables": {
|
||||||
"description": "工作流固定输入变量",
|
"description": "工作流固定输入变量",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"obvious_hint": True,
|
|
||||||
"items": {},
|
"items": {},
|
||||||
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
||||||
"invisible": True,
|
"invisible": True,
|
||||||
},
|
},
|
||||||
# "fastgpt_app_type": {
|
|
||||||
# "description": "应用类型",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "FastGPT 应用的应用类型。",
|
|
||||||
# "options": ["agent", "workflow", "plugin"],
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
"dashscope_app_type": {
|
"dashscope_app_type": {
|
||||||
"description": "应用类型",
|
"description": "应用类型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -1485,7 +1524,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"dialog-workflow",
|
"dialog-workflow",
|
||||||
"task-workflow",
|
"task-workflow",
|
||||||
],
|
],
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"timeout": {
|
"timeout": {
|
||||||
"description": "超时时间",
|
"description": "超时时间",
|
||||||
@@ -1495,26 +1533,22 @@ CONFIG_METADATA_2 = {
|
|||||||
"openai-tts-voice": {
|
"openai-tts-voice": {
|
||||||
"description": "voice",
|
"description": "voice",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||||
},
|
},
|
||||||
"fishaudio-tts-character": {
|
"fishaudio-tts-character": {
|
||||||
"description": "character",
|
"description": "character",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
||||||
},
|
},
|
||||||
"whisper_hint": {
|
"whisper_hint": {
|
||||||
"description": "本地部署 Whisper 模型须知",
|
"description": "本地部署 Whisper 模型须知",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"id": {
|
"id": {
|
||||||
"description": "ID",
|
"description": "ID",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
"hint": "模型提供商名字。",
|
||||||
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
|
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"description": "模型提供商种类",
|
"description": "模型提供商种类",
|
||||||
@@ -1529,53 +1563,27 @@ CONFIG_METADATA_2 = {
|
|||||||
"enable": {
|
"enable": {
|
||||||
"description": "启用",
|
"description": "启用",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "是否启用该模型。未启用的模型将不会被使用。",
|
"hint": "是否启用。",
|
||||||
},
|
},
|
||||||
"key": {
|
"key": {
|
||||||
"description": "API Key",
|
"description": "API Key",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。",
|
"hint": "提供商 API Key。",
|
||||||
},
|
},
|
||||||
"api_base": {
|
"api_base": {
|
||||||
"description": "API Base URL",
|
"description": "API Base URL",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
|
||||||
"base_model_path": {
|
|
||||||
"description": "基座模型路径",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "基座模型路径。",
|
|
||||||
},
|
|
||||||
"adapter_model_path": {
|
|
||||||
"description": "Adapter 模型路径",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Adapter 模型路径。如 Lora",
|
|
||||||
},
|
|
||||||
"llmtuner_template": {
|
|
||||||
"description": "template",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。",
|
|
||||||
},
|
|
||||||
"finetuning_type": {
|
|
||||||
"description": "微调类型",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "微调类型。如 `lora`",
|
|
||||||
},
|
|
||||||
"quantization_bit": {
|
|
||||||
"description": "量化位数",
|
|
||||||
"type": "int",
|
|
||||||
"hint": "量化位数。如 4",
|
|
||||||
},
|
},
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"description": "文本生成模型",
|
"description": "模型配置",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
"model": {
|
"model": {
|
||||||
"description": "模型名称",
|
"description": "模型名称",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "大语言模型的名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat 等。",
|
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||||
},
|
},
|
||||||
"max_tokens": {
|
"max_tokens": {
|
||||||
"description": "模型最大输出长度(tokens)",
|
"description": "模型最大输出长度(tokens)",
|
||||||
@@ -1622,7 +1630,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "启用大语言模型聊天",
|
"description": "启用大语言模型聊天",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"separate_provider": {
|
"separate_provider": {
|
||||||
"description": "提供商会话隔离",
|
"description": "提供商会话隔离",
|
||||||
@@ -1642,25 +1649,26 @@ CONFIG_METADATA_2 = {
|
|||||||
"web_search": {
|
"web_search": {
|
||||||
"description": "启用网页搜索",
|
"description": "启用网页搜索",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||||
},
|
},
|
||||||
"web_search_link": {
|
"web_search_link": {
|
||||||
"description": "网页搜索引用链接",
|
"description": "网页搜索引用链接",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||||
},
|
},
|
||||||
|
"display_reasoning_text": {
|
||||||
|
"description": "显示思考内容",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "开启后,将在回复中显示模型的思考过程。",
|
||||||
|
},
|
||||||
"identifier": {
|
"identifier": {
|
||||||
"description": "启动识别群员",
|
"description": "启动识别群员",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||||
},
|
},
|
||||||
"datetime_system_prompt": {
|
"datetime_system_prompt": {
|
||||||
"description": "启用日期时间系统提示",
|
"description": "启用日期时间系统提示",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||||
},
|
},
|
||||||
"default_personality": {
|
"default_personality": {
|
||||||
@@ -1688,10 +1696,19 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||||
},
|
},
|
||||||
|
"show_tool_use_status": {
|
||||||
|
"description": "函数调用状态输出",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "在触发函数调用时输出其函数名和内容。",
|
||||||
|
},
|
||||||
"streaming_segmented": {
|
"streaming_segmented": {
|
||||||
"description": "不支持流式回复的平台分段输出",
|
"description": "不支持流式回复的平台分段输出",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||||
|
},
|
||||||
|
"max_agent_step": {
|
||||||
|
"description": "工具调用轮数上限",
|
||||||
|
"type": "int",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1712,7 +1729,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "人格名称",
|
"description": "人格名称",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"description": "设定(系统提示词)",
|
"description": "设定(系统提示词)",
|
||||||
@@ -1724,14 +1740,12 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"mood_imitation_dialogs": {
|
"mood_imitation_dialogs": {
|
||||||
"description": "对话风格模仿",
|
"description": "对话风格模仿",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1743,7 +1757,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "启用语音转文本(STT)",
|
"description": "启用语音转文本(STT)",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"provider_id": {
|
"provider_id": {
|
||||||
"description": "提供商 ID",
|
"description": "提供商 ID",
|
||||||
@@ -1760,7 +1773,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "启用文本转语音(TTS)",
|
"description": "启用文本转语音(TTS)",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"provider_id": {
|
"provider_id": {
|
||||||
"description": "提供商 ID",
|
"description": "提供商 ID",
|
||||||
@@ -1771,7 +1783,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "启用语音和文字双输出",
|
"description": "启用语音和文字双输出",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||||
"obvious_hint": True,
|
|
||||||
},
|
},
|
||||||
"use_file_service": {
|
"use_file_service": {
|
||||||
"description": "使用文件服务提供 TTS 语音文件",
|
"description": "使用文件服务提供 TTS 语音文件",
|
||||||
@@ -1787,25 +1798,21 @@ CONFIG_METADATA_2 = {
|
|||||||
"group_icl_enable": {
|
"group_icl_enable": {
|
||||||
"description": "群聊内记录各群员对话",
|
"description": "群聊内记录各群员对话",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||||
},
|
},
|
||||||
"group_message_max_cnt": {
|
"group_message_max_cnt": {
|
||||||
"description": "群聊消息最大数量",
|
"description": "群聊消息最大数量",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||||
},
|
},
|
||||||
"image_caption": {
|
"image_caption": {
|
||||||
"description": "群聊图像转述(需模型支持)",
|
"description": "群聊图像转述(需模型支持)",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
|
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
|
||||||
},
|
},
|
||||||
"image_caption_provider_id": {
|
"image_caption_provider_id": {
|
||||||
"description": "图像转述提供商 ID",
|
"description": "图像转述提供商 ID",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
||||||
},
|
},
|
||||||
"image_caption_prompt": {
|
"image_caption_prompt": {
|
||||||
@@ -1819,14 +1826,12 @@ CONFIG_METADATA_2 = {
|
|||||||
"enable": {
|
"enable": {
|
||||||
"description": "启用主动回复",
|
"description": "启用主动回复",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
|
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
|
||||||
},
|
},
|
||||||
"whitelist": {
|
"whitelist": {
|
||||||
"description": "主动回复白名单",
|
"description": "主动回复白名单",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
|
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
|
||||||
},
|
},
|
||||||
"method": {
|
"method": {
|
||||||
@@ -1838,13 +1843,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"possibility_reply": {
|
"possibility_reply": {
|
||||||
"description": "回复概率",
|
"description": "回复概率",
|
||||||
"type": "float",
|
"type": "float",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||||
},
|
},
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"description": "提示词",
|
"description": "提示词",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1860,7 +1863,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "机器人唤醒前缀",
|
"description": "机器人唤醒前缀",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
||||||
},
|
},
|
||||||
"t2i": {
|
"t2i": {
|
||||||
@@ -1887,13 +1889,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"timezone": {
|
"timezone": {
|
||||||
"description": "时区",
|
"description": "时区",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||||
},
|
},
|
||||||
"callback_api_base": {
|
"callback_api_base": {
|
||||||
"description": "对外可达的回调接口地址",
|
"description": "对外可达的回调接口地址",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"obvious_hint": True,
|
|
||||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
||||||
},
|
},
|
||||||
"log_level": {
|
"log_level": {
|
||||||
@@ -1941,90 +1941,3 @@ DEFAULT_VALUE_MAP = {
|
|||||||
"list": [],
|
"list": [],
|
||||||
"object": {},
|
"object": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# "project_atri": {
|
|
||||||
# "description": "Project ATRI 配置",
|
|
||||||
# "type": "object",
|
|
||||||
# "items": {
|
|
||||||
# "enable": {"description": "启用", "type": "bool"},
|
|
||||||
# "long_term_memory": {
|
|
||||||
# "description": "长期记忆",
|
|
||||||
# "type": "object",
|
|
||||||
# "items": {
|
|
||||||
# "enable": {"description": "启用", "type": "bool"},
|
|
||||||
# "summary_threshold_cnt": {
|
|
||||||
# "description": "摘要阈值",
|
|
||||||
# "type": "int",
|
|
||||||
# "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。",
|
|
||||||
# },
|
|
||||||
# "embedding_provider_id": {
|
|
||||||
# "description": "Embedding provider ID",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# "summarize_provider_id": {
|
|
||||||
# "description": "Summary provider ID",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
# "active_message": {
|
|
||||||
# "description": "主动消息",
|
|
||||||
# "type": "object",
|
|
||||||
# "items": {
|
|
||||||
# "enable": {"description": "启用", "type": "bool"},
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
# "vision": {
|
|
||||||
# "description": "视觉理解",
|
|
||||||
# "type": "object",
|
|
||||||
# "items": {
|
|
||||||
# "enable": {"description": "启用", "type": "bool"},
|
|
||||||
# "provider_id_or_ofa_model_path": {
|
|
||||||
# "description": "提供商 ID 或 OFA 模型路径",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。",
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
# "split_response": {
|
|
||||||
# "description": "是否分割回复",
|
|
||||||
# "type": "bool",
|
|
||||||
# "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。",
|
|
||||||
# },
|
|
||||||
# "persona": {
|
|
||||||
# "description": "人格",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# "chat_provider_id": {
|
|
||||||
# "description": "Chat provider ID",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# "chat_base_model_path": {
|
|
||||||
# "description": "用于聊天的基座模型路径",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# "chat_adapter_model_path": {
|
|
||||||
# "description": "用于聊天的 Lora 模型路径",
|
|
||||||
# "type": "string",
|
|
||||||
# "hint": "Lora 模型路径。",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# "quantization_bit": {
|
|
||||||
# "description": "量化位数",
|
|
||||||
# "type": "int",
|
|
||||||
# "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。",
|
|
||||||
# "obvious_hint": True,
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
# },
|
|
||||||
|
|||||||
@@ -88,7 +88,10 @@ class ConversationManager:
|
|||||||
return self.session_conversations.get(unified_msg_origin, None)
|
return self.session_conversations.get(unified_msg_origin, None)
|
||||||
|
|
||||||
async def get_conversation(
|
async def get_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str
|
self,
|
||||||
|
unified_msg_origin: str,
|
||||||
|
conversation_id: str,
|
||||||
|
create_if_not_exists: bool = False,
|
||||||
) -> Conversation:
|
) -> Conversation:
|
||||||
"""获取会话的对话
|
"""获取会话的对话
|
||||||
|
|
||||||
@@ -98,6 +101,13 @@ class ConversationManager:
|
|||||||
Returns:
|
Returns:
|
||||||
conversation (Conversation): 对话对象
|
conversation (Conversation): 对话对象
|
||||||
"""
|
"""
|
||||||
|
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||||
|
if not conv and create_if_not_exists:
|
||||||
|
# 如果对话不存在且需要创建,则新建一个对话
|
||||||
|
conversation_id = await self.new_conversation(unified_msg_origin)
|
||||||
|
return self.db.get_conversation_by_user_id(
|
||||||
|
unified_msg_origin, conversation_id
|
||||||
|
)
|
||||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||||
|
|
||||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||||
|
|||||||
@@ -46,9 +46,12 @@ class AstrBotCoreLifecycle:
|
|||||||
self.astrbot_config = astrbot_config # 初始化配置
|
self.astrbot_config = astrbot_config # 初始化配置
|
||||||
self.db = db # 初始化数据库
|
self.db = db # 初始化数据库
|
||||||
|
|
||||||
# 根据环境变量设置代理
|
# 设置代理
|
||||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
if self.astrbot_config.get("http_proxy", ""):
|
||||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
|
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
|
if proxy := os.environ.get("https_proxy"):
|
||||||
|
logger.debug(f"Using proxy: {proxy}")
|
||||||
os.environ["no_proxy"] = "localhost"
|
os.environ["no_proxy"] = "localhost"
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
from urllib.parse import urlparse, unquote
|
||||||
|
import platform
|
||||||
|
|
||||||
|
|
||||||
class FileTokenService:
|
class FileTokenService:
|
||||||
@@ -15,7 +17,9 @@ class FileTokenService:
|
|||||||
async def _cleanup_expired_tokens(self):
|
async def _cleanup_expired_tokens(self):
|
||||||
"""清理过期的令牌"""
|
"""清理过期的令牌"""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
expired_tokens = [
|
||||||
|
token for token, (_, expire) in self.staged_files.items() if expire < now
|
||||||
|
]
|
||||||
for token in expired_tokens:
|
for token in expired_tokens:
|
||||||
self.staged_files.pop(token, None)
|
self.staged_files.pop(token, None)
|
||||||
|
|
||||||
@@ -32,15 +36,35 @@ class FileTokenService:
|
|||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: 当路径不存在时抛出
|
FileNotFoundError: 当路径不存在时抛出
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 处理 file:///
|
||||||
|
try:
|
||||||
|
parsed_uri = urlparse(file_path)
|
||||||
|
if parsed_uri.scheme == "file":
|
||||||
|
local_path = unquote(parsed_uri.path)
|
||||||
|
if platform.system() == "Windows" and local_path.startswith("/"):
|
||||||
|
local_path = local_path[1:]
|
||||||
|
else:
|
||||||
|
# 如果没有 file:/// 前缀,则认为是普通路径
|
||||||
|
local_path = file_path
|
||||||
|
except Exception:
|
||||||
|
# 解析失败时,按原路径处理
|
||||||
|
local_path = file_path
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
await self._cleanup_expired_tokens()
|
await self._cleanup_expired_tokens()
|
||||||
|
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(local_path):
|
||||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
raise FileNotFoundError(
|
||||||
|
f"文件不存在: {local_path} (原始输入: {file_path})"
|
||||||
|
)
|
||||||
|
|
||||||
file_token = str(uuid.uuid4())
|
file_token = str(uuid.uuid4())
|
||||||
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
expire_time = time.time() + (
|
||||||
self.staged_files[file_token] = (file_path, expire_time)
|
timeout if timeout is not None else self.default_timeout
|
||||||
|
)
|
||||||
|
# 存储转换后的真实路径
|
||||||
|
self.staged_files[file_token] = (local_path, expire_time)
|
||||||
return file_token
|
return file_token
|
||||||
|
|
||||||
async def handle_file(self, file_token: str) -> str:
|
async def handle_file(self, file_token: str) -> str:
|
||||||
|
|||||||
@@ -96,8 +96,6 @@ class LogBroker:
|
|||||||
Queue: 订阅者的队列, 可用于接收日志消息
|
Queue: 订阅者的队列, 可用于接收日志消息
|
||||||
"""
|
"""
|
||||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||||
for log in self.log_cache:
|
|
||||||
q.put_nowait(log)
|
|
||||||
self.subscribers.append(q)
|
self.subscribers.append(q)
|
||||||
return q
|
return q
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,9 @@ class Plain(BaseMessageComponent):
|
|||||||
def toDict(self):
|
def toDict(self):
|
||||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||||
|
|
||||||
|
async def to_dict(self):
|
||||||
|
return {"type": "text", "data": {"text": self.text}}
|
||||||
|
|
||||||
|
|
||||||
class Face(BaseMessageComponent):
|
class Face(BaseMessageComponent):
|
||||||
type: ComponentType = "Face"
|
type: ComponentType = "Face"
|
||||||
@@ -610,6 +613,10 @@ class Node(BaseMessageComponent):
|
|||||||
"data": {"file": f"base64://{bs64}"},
|
"data": {"file": f"base64://{bs64}"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif isinstance(comp, Plain):
|
||||||
|
# For Plain segments, we need to handle the plain differently
|
||||||
|
d = await comp.to_dict()
|
||||||
|
data_content.append(d)
|
||||||
elif isinstance(comp, File):
|
elif isinstance(comp, File):
|
||||||
# For File segments, we need to handle the file differently
|
# For File segments, we need to handle the file differently
|
||||||
d = await comp.to_dict()
|
d = await comp.to_dict()
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ class MessageChain:
|
|||||||
|
|
||||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||||
|
type: Optional[str] = None
|
||||||
|
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
|
||||||
|
|
||||||
def message(self, message: str):
|
def message(self, message: str):
|
||||||
"""添加一条文本消息到消息链 `chain` 中。
|
"""添加一条文本消息到消息链 `chain` 中。
|
||||||
@@ -98,6 +100,15 @@ class MessageChain:
|
|||||||
self.chain.append(Image.fromFileSystem(path))
|
self.chain.append(Image.fromFileSystem(path))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def base64_image(self, base64_str: str):
|
||||||
|
"""添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
|
||||||
|
"""
|
||||||
|
self.chain.append(Image.fromBase64(base64_str))
|
||||||
|
return self
|
||||||
|
|
||||||
def use_t2i(self, use_t2i: bool):
|
def use_t2i(self, use_t2i: bool):
|
||||||
"""设置是否使用文本转图片服务。
|
"""设置是否使用文本转图片服务。
|
||||||
|
|
||||||
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
|
|||||||
"""普通的消息结果"""
|
"""普通的消息结果"""
|
||||||
STREAMING_RESULT = enum.auto()
|
STREAMING_RESULT = enum.auto()
|
||||||
"""调用 LLM 产生的流式结果"""
|
"""调用 LLM 产生的流式结果"""
|
||||||
STREAMING_FINISH= enum.auto()
|
STREAMING_FINISH = enum.auto()
|
||||||
"""流式输出完成"""
|
"""流式输出完成"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,24 @@
|
|||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageEventResult,
|
|
||||||
EventResultType,
|
EventResultType,
|
||||||
|
MessageEventResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .waking_check.stage import WakingCheckStage
|
|
||||||
from .whitelist_check.stage import WhitelistCheckStage
|
|
||||||
from .rate_limit_check.stage import RateLimitStage
|
|
||||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||||
from .preprocess_stage.stage import PreProcessStage
|
from .preprocess_stage.stage import PreProcessStage
|
||||||
from .process_stage.stage import ProcessStage
|
from .process_stage.stage import ProcessStage
|
||||||
from .result_decorate.stage import ResultDecorateStage
|
from .rate_limit_check.stage import RateLimitStage
|
||||||
from .respond.stage import RespondStage
|
from .respond.stage import RespondStage
|
||||||
|
from .result_decorate.stage import ResultDecorateStage
|
||||||
|
from .session_status_check.stage import SessionStatusCheckStage
|
||||||
|
from .waking_check.stage import WakingCheckStage
|
||||||
|
from .whitelist_check.stage import WhitelistCheckStage
|
||||||
|
|
||||||
# 管道阶段顺序
|
# 管道阶段顺序
|
||||||
STAGES_ORDER = [
|
STAGES_ORDER = [
|
||||||
"WakingCheckStage", # 检查是否需要唤醒
|
"WakingCheckStage", # 检查是否需要唤醒
|
||||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||||
|
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||||
"RateLimitStage", # 检查会话是否超过频率限制
|
"RateLimitStage", # 检查会话是否超过频率限制
|
||||||
"ContentSafetyCheckStage", # 检查内容安全
|
"ContentSafetyCheckStage", # 检查内容安全
|
||||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||||
@@ -29,6 +31,7 @@ STAGES_ORDER = [
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"WakingCheckStage",
|
"WakingCheckStage",
|
||||||
"WhitelistCheckStage",
|
"WhitelistCheckStage",
|
||||||
|
"SessionStatusCheckStage",
|
||||||
"RateLimitStage",
|
"RateLimitStage",
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
"PlatformCompatibilityStage",
|
"PlatformCompatibilityStage",
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
|
import inspect
|
||||||
|
import traceback
|
||||||
|
import typing as T
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star import PluginManager
|
from astrbot.core.star import PluginManager
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -9,3 +17,97 @@ class PipelineContext:
|
|||||||
|
|
||||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||||
plugin_manager: PluginManager # 插件管理器对象
|
plugin_manager: PluginManager # 插件管理器对象
|
||||||
|
|
||||||
|
async def call_event_hook(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
hook_type: EventType,
|
||||||
|
*args,
|
||||||
|
) -> bool:
|
||||||
|
"""调用事件钩子函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果事件被终止,返回 True
|
||||||
|
"""
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
hook_type, platform_id=platform_id
|
||||||
|
)
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler(event, *args)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if event.is_stopped():
|
||||||
|
logger.info(
|
||||||
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
|
||||||
|
return event.is_stopped()
|
||||||
|
|
||||||
|
async def call_handler(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
handler: T.Awaitable,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> T.AsyncGenerator[None, None]:
|
||||||
|
"""执行事件处理函数并处理其返回结果
|
||||||
|
|
||||||
|
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||||
|
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||||
|
2. 协程: 执行一次并处理返回值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
handler (Awaitable): 事件处理函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||||
|
"""
|
||||||
|
ready_to_call = None # 一个协程或者异步生成器
|
||||||
|
|
||||||
|
trace_ = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
ready_to_call = handler(event, *args, **kwargs)
|
||||||
|
except TypeError as _:
|
||||||
|
# 向下兼容
|
||||||
|
trace_ = traceback.format_exc()
|
||||||
|
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
||||||
|
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
||||||
|
|
||||||
|
if inspect.isasyncgen(ready_to_call):
|
||||||
|
_has_yielded = False
|
||||||
|
try:
|
||||||
|
async for ret in ready_to_call:
|
||||||
|
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||||
|
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||||
|
_has_yielded = True
|
||||||
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
|
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||||
|
event.set_result(ret)
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
# 如果返回值是 None, 则不设置结果并继续
|
||||||
|
# 继续执行后续阶段
|
||||||
|
yield ret
|
||||||
|
if not _has_yielded:
|
||||||
|
# 如果这个异步生成器没有执行到 yield 分支
|
||||||
|
yield
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Previous Error: {trace_}")
|
||||||
|
raise e
|
||||||
|
elif inspect.iscoroutine(ready_to_call):
|
||||||
|
# 如果只是一个协程, 直接执行
|
||||||
|
ret = await ready_to_call
|
||||||
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
|
event.set_result(ret)
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield ret
|
||||||
|
|||||||
58
astrbot/core/pipeline/process_stage/agent_runner/base.py
Normal file
58
astrbot/core/pipeline/process_stage/agent_runner/base.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import abc
|
||||||
|
import typing as T
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
|
from ....message.message_event_result import MessageChain
|
||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState(Enum):
|
||||||
|
"""Agent 状态枚举"""
|
||||||
|
|
||||||
|
IDLE = auto() # 初始状态
|
||||||
|
RUNNING = auto() # 运行中
|
||||||
|
DONE = auto() # 完成
|
||||||
|
ERROR = auto() # 错误状态
|
||||||
|
|
||||||
|
|
||||||
|
class AgentResponseData(T.TypedDict):
|
||||||
|
chain: MessageChain
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentResponse:
|
||||||
|
type: str
|
||||||
|
data: AgentResponseData
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgentRunner:
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def reset(self) -> None:
|
||||||
|
"""
|
||||||
|
Reset the agent to its initial state.
|
||||||
|
This method should be called before starting a new run.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
|
||||||
|
"""
|
||||||
|
Process a single step of the agent.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the agent has completed its task.
|
||||||
|
Returns True if the agent is done, False otherwise.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||||
|
"""
|
||||||
|
Get the final observation from the agent.
|
||||||
|
This method should be called after the agent is done.
|
||||||
|
"""
|
||||||
|
...
|
||||||
@@ -0,0 +1,306 @@
|
|||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import typing as T
|
||||||
|
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
|
||||||
|
from ...context import PipelineContext
|
||||||
|
from astrbot.core.provider.provider import Provider
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.message.message_event_result import (
|
||||||
|
MessageChain,
|
||||||
|
)
|
||||||
|
from astrbot.core.provider.entities import (
|
||||||
|
ProviderRequest,
|
||||||
|
LLMResponse,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
AssistantMessageSegment,
|
||||||
|
ToolCallsResult,
|
||||||
|
)
|
||||||
|
from mcp.types import (
|
||||||
|
TextContent,
|
||||||
|
ImageContent,
|
||||||
|
EmbeddedResource,
|
||||||
|
TextResourceContents,
|
||||||
|
BlobResourceContents,
|
||||||
|
)
|
||||||
|
from astrbot.core.star.star_handler import EventType
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12):
|
||||||
|
from typing import override
|
||||||
|
else:
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
# 1. 处理平台不兼容的处理器
|
||||||
|
|
||||||
|
|
||||||
|
class ToolLoopAgent(BaseAgentRunner):
|
||||||
|
def __init__(
|
||||||
|
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
|
||||||
|
) -> None:
|
||||||
|
self.provider = provider
|
||||||
|
self.req = None
|
||||||
|
self.event = event
|
||||||
|
self.pipeline_ctx = pipeline_ctx
|
||||||
|
self._state = AgentState.IDLE
|
||||||
|
self.final_llm_resp = None
|
||||||
|
self.streaming = False
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
|
||||||
|
self.req = req
|
||||||
|
self.streaming = streaming
|
||||||
|
self.final_llm_resp = None
|
||||||
|
self._state = AgentState.IDLE
|
||||||
|
|
||||||
|
def _transition_state(self, new_state: AgentState) -> None:
|
||||||
|
"""转换 Agent 状态"""
|
||||||
|
if self._state != new_state:
|
||||||
|
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||||
|
self._state = new_state
|
||||||
|
|
||||||
|
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""Yields chunks *and* a final LLMResponse."""
|
||||||
|
if self.streaming:
|
||||||
|
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||||
|
async for resp in stream: # type: ignore
|
||||||
|
yield resp
|
||||||
|
else:
|
||||||
|
yield await self.provider.text_chat(**self.req.__dict__)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def step(self):
|
||||||
|
"""
|
||||||
|
Process a single step of the agent.
|
||||||
|
This method should return the result of the step.
|
||||||
|
"""
|
||||||
|
if not self.req:
|
||||||
|
raise ValueError("Request is not set. Please call reset() first.")
|
||||||
|
|
||||||
|
# 开始处理,转换到运行状态
|
||||||
|
self._transition_state(AgentState.RUNNING)
|
||||||
|
llm_resp_result = None
|
||||||
|
|
||||||
|
async for llm_response in self._iter_llm_responses():
|
||||||
|
assert isinstance(llm_response, LLMResponse)
|
||||||
|
if llm_response.is_chunk:
|
||||||
|
if llm_response.result_chain:
|
||||||
|
yield AgentResponse(
|
||||||
|
type="streaming_delta",
|
||||||
|
data=AgentResponseData(chain=llm_response.result_chain),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield AgentResponse(
|
||||||
|
type="streaming_delta",
|
||||||
|
data=AgentResponseData(
|
||||||
|
chain=MessageChain().message(llm_response.completion_text)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
llm_resp_result = llm_response
|
||||||
|
break # got final response
|
||||||
|
|
||||||
|
if not llm_resp_result:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 处理 LLM 响应
|
||||||
|
llm_resp = llm_resp_result
|
||||||
|
|
||||||
|
if llm_resp.role == "err":
|
||||||
|
# 如果 LLM 响应错误,转换到错误状态
|
||||||
|
self.final_llm_resp = llm_resp
|
||||||
|
self._transition_state(AgentState.ERROR)
|
||||||
|
yield AgentResponse(
|
||||||
|
type="err",
|
||||||
|
data=AgentResponseData(
|
||||||
|
chain=MessageChain().message(
|
||||||
|
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not llm_resp.tools_call_name:
|
||||||
|
# 如果没有工具调用,转换到完成状态
|
||||||
|
self.final_llm_resp = llm_resp
|
||||||
|
self._transition_state(AgentState.DONE)
|
||||||
|
|
||||||
|
# 执行事件钩子
|
||||||
|
if await self.pipeline_ctx.call_event_hook(
|
||||||
|
self.event, EventType.OnLLMResponseEvent, llm_resp
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# 返回 LLM 结果
|
||||||
|
if llm_resp.result_chain:
|
||||||
|
yield AgentResponse(
|
||||||
|
type="llm_result",
|
||||||
|
data=AgentResponseData(chain=llm_resp.result_chain),
|
||||||
|
)
|
||||||
|
elif llm_resp.completion_text:
|
||||||
|
yield AgentResponse(
|
||||||
|
type="llm_result",
|
||||||
|
data=AgentResponseData(
|
||||||
|
chain=MessageChain().message(llm_resp.completion_text)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有工具调用,还需处理工具调用
|
||||||
|
if llm_resp.tools_call_name:
|
||||||
|
tool_call_result_blocks = []
|
||||||
|
for tool_call_name in llm_resp.tools_call_name:
|
||||||
|
yield AgentResponse(
|
||||||
|
type="tool_call",
|
||||||
|
data=AgentResponseData(
|
||||||
|
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||||
|
if isinstance(result, list):
|
||||||
|
tool_call_result_blocks = result
|
||||||
|
elif isinstance(result, MessageChain):
|
||||||
|
yield AgentResponse(
|
||||||
|
type="tool_call_result",
|
||||||
|
data=AgentResponseData(chain=result),
|
||||||
|
)
|
||||||
|
# 将结果添加到上下文中
|
||||||
|
tool_calls_result = ToolCallsResult(
|
||||||
|
tool_calls_info=AssistantMessageSegment(
|
||||||
|
role="assistant",
|
||||||
|
tool_calls=llm_resp.to_openai_tool_calls(),
|
||||||
|
content=llm_resp.completion_text,
|
||||||
|
),
|
||||||
|
tool_calls_result=tool_call_result_blocks,
|
||||||
|
)
|
||||||
|
self.req.append_tool_calls_result(tool_calls_result)
|
||||||
|
|
||||||
|
async def _handle_function_tools(
|
||||||
|
self,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
|
||||||
|
"""处理函数工具调用。"""
|
||||||
|
tool_call_result_blocks: list[ToolCallMessageSegment] = []
|
||||||
|
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
|
||||||
|
|
||||||
|
# 执行函数调用
|
||||||
|
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||||
|
llm_response.tools_call_name,
|
||||||
|
llm_response.tools_call_args,
|
||||||
|
llm_response.tools_call_ids,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
if not req.func_tool:
|
||||||
|
return
|
||||||
|
func_tool = req.func_tool.get_func(func_tool_name)
|
||||||
|
if func_tool.origin == "mcp":
|
||||||
|
logger.info(
|
||||||
|
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||||
|
)
|
||||||
|
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||||
|
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||||
|
if not res:
|
||||||
|
continue
|
||||||
|
if isinstance(res.content[0], TextContent):
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=res.content[0].text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield MessageChain().message(res.content[0].text)
|
||||||
|
elif isinstance(res.content[0], ImageContent):
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content="返回了图片(已直接发送给用户)",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield MessageChain(type="tool_direct_result").base64_image(
|
||||||
|
res.content[0].data
|
||||||
|
)
|
||||||
|
elif isinstance(res.content[0], EmbeddedResource):
|
||||||
|
resource = res.content[0].resource
|
||||||
|
if isinstance(resource, TextResourceContents):
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=resource.text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield MessageChain().message(resource.text)
|
||||||
|
elif (
|
||||||
|
isinstance(resource, BlobResourceContents)
|
||||||
|
and resource.mimeType
|
||||||
|
and resource.mimeType.startswith("image/")
|
||||||
|
):
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content="返回了图片(已直接发送给用户)",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield MessageChain(type="tool_direct_result").base64_image(
|
||||||
|
res.content[0].data
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content="返回的数据类型不受支持",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield MessageChain().message("返回的数据类型不受支持。")
|
||||||
|
else:
|
||||||
|
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||||
|
# 尝试调用工具函数
|
||||||
|
wrapper = self.pipeline_ctx.call_handler(
|
||||||
|
self.event, func_tool.handler, **func_tool_args
|
||||||
|
)
|
||||||
|
async for resp in wrapper:
|
||||||
|
if resp is not None:
|
||||||
|
# Tool 返回结果
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=resp,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield MessageChain().message(resp)
|
||||||
|
else:
|
||||||
|
# Tool 直接请求发送消息给用户
|
||||||
|
# 这里我们将直接结束 Agent Loop。
|
||||||
|
self._transition_state(AgentState.DONE)
|
||||||
|
if res := self.event.get_result():
|
||||||
|
if res.chain:
|
||||||
|
yield MessageChain(
|
||||||
|
chain=res.chain, type="tool_direct_result"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.event.clear_result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(traceback.format_exc())
|
||||||
|
tool_call_result_blocks.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=f"error: {str(e)}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理函数调用响应
|
||||||
|
if tool_call_result_blocks:
|
||||||
|
yield tool_call_result_blocks
|
||||||
|
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""检查 Agent 是否已完成工作"""
|
||||||
|
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||||
|
|
||||||
|
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||||
|
return self.final_llm_resp
|
||||||
@@ -2,57 +2,47 @@
|
|||||||
本地 Agent 模式的 LLM 调用 Stage
|
本地 Agent 模式的 LLM 调用 Stage
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
from typing import Union, AsyncGenerator
|
import traceback
|
||||||
from ...context import PipelineContext
|
from typing import AsyncGenerator, Union
|
||||||
from ..stage import Stage
|
from astrbot.core import logger
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
|
MessageChain,
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
ResultContentType,
|
ResultContentType,
|
||||||
MessageChain,
|
|
||||||
)
|
)
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core import logger
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.utils.metrics import Metric
|
|
||||||
from astrbot.core.provider.entities import (
|
from astrbot.core.provider.entities import (
|
||||||
ProviderRequest,
|
|
||||||
LLMResponse,
|
LLMResponse,
|
||||||
ToolCallMessageSegment,
|
ProviderRequest,
|
||||||
AssistantMessageSegment,
|
|
||||||
ToolCallsResult,
|
|
||||||
)
|
)
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star_handler import EventType
|
||||||
from mcp.types import (
|
from astrbot.core.utils.metrics import Metric
|
||||||
TextContent,
|
from ...context import PipelineContext
|
||||||
ImageContent,
|
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||||
EmbeddedResource,
|
from ..stage import Stage
|
||||||
TextResourceContents,
|
|
||||||
BlobResourceContents,
|
|
||||||
)
|
|
||||||
from astrbot.core import web_chat_back_queue
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestSubStage(Stage):
|
class LLMRequestSubStage(Stage):
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list
|
conf = ctx.astrbot_config
|
||||||
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
settings = conf["provider_settings"]
|
||||||
"wake_prefix"
|
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||||
] # str
|
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||||
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
self.max_context_length = settings["max_context_length"] # int
|
||||||
"max_context_length"
|
self.dequeue_context_length: int = min(
|
||||||
] # int
|
max(1, settings["dequeue_context_length"]),
|
||||||
self.dequeue_context_length = min(
|
|
||||||
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
|
||||||
self.max_context_length - 1,
|
self.max_context_length - 1,
|
||||||
) # int
|
)
|
||||||
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
self.streaming_response: bool = settings["streaming_response"]
|
||||||
"streaming_response"
|
self.max_step: int = settings.get("max_agent_step", 10)
|
||||||
] # bool
|
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||||
|
|
||||||
for bwp in self.bot_wake_prefixs:
|
for bwp in self.bot_wake_prefixs:
|
||||||
if self.provider_wake_prefix.startswith(bwp):
|
if self.provider_wake_prefix.startswith(bwp):
|
||||||
@@ -63,16 +53,33 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||||
|
|
||||||
|
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
|
||||||
|
"""选择使用的 LLM 提供商"""
|
||||||
|
sel_provider = event.get_extra("selected_provider")
|
||||||
|
_ctx = self.ctx.plugin_manager.context
|
||||||
|
if sel_provider and isinstance(sel_provider, str):
|
||||||
|
provider = _ctx.get_provider_by_id(sel_provider)
|
||||||
|
if not provider:
|
||||||
|
logger.error(f"未找到指定的提供商: {sel_provider}。")
|
||||||
|
return provider
|
||||||
|
|
||||||
|
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent, _nested: bool = False
|
self, event: AstrMessageEvent, _nested: bool = False
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
req: ProviderRequest = None
|
req: ProviderRequest | None = None
|
||||||
|
|
||||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||||
return
|
return
|
||||||
umo = event.unified_msg_origin
|
|
||||||
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
|
# 检查会话级别的LLM启停状态
|
||||||
|
if not SessionServiceManager.should_process_llm_request(event):
|
||||||
|
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||||
|
return
|
||||||
|
|
||||||
|
provider = self._select_provider(event)
|
||||||
if provider is None:
|
if provider is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -83,13 +90,12 @@ class LLMRequestSubStage(Stage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if req.conversation:
|
if req.conversation:
|
||||||
all_contexts = json.loads(req.conversation.history)
|
req.contexts = json.loads(req.conversation.history)
|
||||||
req.contexts = self._process_tool_message_pairs(
|
|
||||||
all_contexts, remove_tags=True
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
req = ProviderRequest(prompt="", image_urls=[])
|
req = ProviderRequest(prompt="", image_urls=[])
|
||||||
|
if sel_model := event.get_extra("selected_model"):
|
||||||
|
req.model = sel_model
|
||||||
if self.provider_wake_prefix:
|
if self.provider_wake_prefix:
|
||||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||||
return
|
return
|
||||||
@@ -127,26 +133,8 @@ class LLMRequestSubStage(Stage):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 执行请求 LLM 前事件钩子。
|
# 执行请求 LLM 前事件钩子。
|
||||||
# 装饰 system_prompt 等功能
|
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||||
# 获取当前平台ID
|
return
|
||||||
platform_id = event.get_platform_id()
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
|
||||||
EventType.OnLLMRequestEvent, platform_id=platform_id
|
|
||||||
)
|
|
||||||
for handler in handlers:
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
|
||||||
)
|
|
||||||
await handler.handler(event, req)
|
|
||||||
except BaseException:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
if event.is_stopped():
|
|
||||||
logger.info(
|
|
||||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(req.contexts, str):
|
if isinstance(req.contexts, str):
|
||||||
req.contexts = json.loads(req.contexts)
|
req.contexts = json.loads(req.contexts)
|
||||||
@@ -176,77 +164,77 @@ class LLMRequestSubStage(Stage):
|
|||||||
if not req.session_id:
|
if not req.session_id:
|
||||||
req.session_id = event.unified_msg_origin
|
req.session_id = event.unified_msg_origin
|
||||||
|
|
||||||
async def requesting(req: ProviderRequest):
|
# fix messages
|
||||||
try:
|
req.contexts = self.fix_messages(req.contexts)
|
||||||
need_loop = True
|
|
||||||
while need_loop:
|
|
||||||
need_loop = False
|
|
||||||
logger.debug(f"提供商请求 Payload: {req}")
|
|
||||||
|
|
||||||
final_llm_response = None
|
# Call Agent
|
||||||
|
tool_loop_agent = ToolLoopAgent(
|
||||||
if self.streaming_response:
|
provider=provider,
|
||||||
stream = provider.text_chat_stream(**req.__dict__)
|
event=event,
|
||||||
async for llm_response in stream:
|
pipeline_ctx=self.ctx,
|
||||||
if llm_response.is_chunk:
|
)
|
||||||
if llm_response.result_chain:
|
logger.debug(
|
||||||
yield llm_response.result_chain # MessageChain
|
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
|
||||||
else:
|
)
|
||||||
yield MessageChain().message(
|
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||||
llm_response.completion_text
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
final_llm_response = llm_response
|
|
||||||
else:
|
|
||||||
final_llm_response = await provider.text_chat(
|
|
||||||
**req.__dict__
|
|
||||||
) # 请求 LLM
|
|
||||||
|
|
||||||
if not final_llm_response:
|
|
||||||
raise Exception("LLM response is None.")
|
|
||||||
|
|
||||||
# 执行 LLM 响应后的事件钩子。
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
|
||||||
EventType.OnLLMResponseEvent
|
|
||||||
)
|
|
||||||
for handler in handlers:
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
|
||||||
)
|
|
||||||
await handler.handler(event, final_llm_response)
|
|
||||||
except BaseException:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
|
async def requesting():
|
||||||
|
step_idx = 0
|
||||||
|
while step_idx < self.max_step:
|
||||||
|
step_idx += 1
|
||||||
|
try:
|
||||||
|
async for resp in tool_loop_agent.step():
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.info(
|
|
||||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
if resp.type == "tool_call_result":
|
||||||
|
msg_chain = resp.data["chain"]
|
||||||
|
if msg_chain.type == "tool_direct_result":
|
||||||
|
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||||
|
resp.data["chain"].type = "tool_call_result"
|
||||||
|
await event.send(resp.data["chain"])
|
||||||
|
continue
|
||||||
|
# 对于其他情况,暂时先不处理
|
||||||
|
continue
|
||||||
|
elif resp.type == "tool_call":
|
||||||
|
if self.streaming_response:
|
||||||
|
# 用来标记流式响应需要分节
|
||||||
|
yield MessageChain(chain=[], type="break")
|
||||||
|
if (
|
||||||
|
self.show_tool_use
|
||||||
|
or event.get_platform_name() == "webchat"
|
||||||
|
):
|
||||||
|
resp.data["chain"].type = "tool_call"
|
||||||
|
await event.send(resp.data["chain"])
|
||||||
|
continue
|
||||||
|
|
||||||
if self.streaming_response:
|
if not self.streaming_response:
|
||||||
# 流式输出的处理
|
content_typ = (
|
||||||
async for result in self._handle_llm_stream_response(
|
ResultContentType.LLM_RESULT
|
||||||
event, req, final_llm_response
|
if resp.type == "llm_result"
|
||||||
):
|
else ResultContentType.GENERAL_RESULT
|
||||||
if isinstance(result, ProviderRequest):
|
)
|
||||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
event.set_result(
|
||||||
req = result
|
MessageEventResult(
|
||||||
need_loop = True
|
chain=resp.data["chain"].chain,
|
||||||
else:
|
result_content_type=content_typ,
|
||||||
yield
|
)
|
||||||
else:
|
)
|
||||||
# 非流式输出的处理
|
yield
|
||||||
async for result in self._handle_llm_response(
|
event.clear_result()
|
||||||
event, req, final_llm_response
|
else:
|
||||||
):
|
if resp.type == "streaming_delta":
|
||||||
if isinstance(result, ProviderRequest):
|
yield resp.data["chain"] # MessageChain
|
||||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
if tool_loop_agent.done():
|
||||||
req = result
|
break
|
||||||
need_loop = True
|
|
||||||
else:
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
Metric.upload(
|
Metric.upload(
|
||||||
llm_tick=1,
|
llm_tick=1,
|
||||||
@@ -255,45 +243,41 @@ class LLMRequestSubStage(Stage):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存到历史记录
|
if self.streaming_response:
|
||||||
await self._save_to_history(event, req, final_llm_response)
|
# 流式响应
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(
|
|
||||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.streaming_response:
|
|
||||||
event.set_extra("tool_call_result", None)
|
|
||||||
async for _ in requesting(req):
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
event.set_result(
|
event.set_result(
|
||||||
MessageEventResult()
|
MessageEventResult()
|
||||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||||
.set_async_stream(requesting(req))
|
.set_async_stream(requesting())
|
||||||
)
|
)
|
||||||
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
|
||||||
yield
|
yield
|
||||||
|
if tool_loop_agent.done():
|
||||||
if event.get_extra("tool_call_result"):
|
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
|
||||||
event.set_result(event.get_extra("tool_call_result"))
|
if final_llm_resp.completion_text:
|
||||||
event.set_extra("tool_call_result", None)
|
chain = (
|
||||||
|
MessageChain().message(final_llm_resp.completion_text).chain
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chain = final_llm_resp.result_chain.chain
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult(
|
||||||
|
chain=chain,
|
||||||
|
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
async for _ in requesting():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 暂时直接发出去
|
# 异步处理 WebChat 特殊情况
|
||||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
|
||||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
|
||||||
event.set_extra("tool_call_img_respond", None)
|
|
||||||
|
|
||||||
if event.get_platform_name() == "webchat":
|
if event.get_platform_name() == "webchat":
|
||||||
# 异步处理 WebChat 特殊情况
|
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||||
asyncio.create_task(self._handle_webchat(event, req))
|
|
||||||
|
|
||||||
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
|
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
|
||||||
|
|
||||||
|
async def _handle_webchat(
|
||||||
|
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||||
|
):
|
||||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||||
conversation = await self.conv_manager.get_conversation(
|
conversation = await self.conv_manager.get_conversation(
|
||||||
event.unified_msg_origin, req.conversation.cid
|
event.unified_msg_origin, req.conversation.cid
|
||||||
@@ -303,21 +287,16 @@ class LLMRequestSubStage(Stage):
|
|||||||
latest_pair = messages[-2:]
|
latest_pair = messages[-2:]
|
||||||
if not latest_pair:
|
if not latest_pair:
|
||||||
return
|
return
|
||||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
|
||||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||||
# if len(latest_pair) > 1:
|
|
||||||
# cleaned_text += (
|
|
||||||
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
|
|
||||||
# )
|
|
||||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||||
llm_resp = await provider.text_chat(
|
llm_resp = await prov.text_chat(
|
||||||
system_prompt="You are expert in summarizing user's query.",
|
system_prompt="You are expert in summarizing user's query.",
|
||||||
prompt=(
|
prompt=(
|
||||||
f"Please summarize the following query of user:\n"
|
f"Please summarize the following query of user:\n"
|
||||||
f"{cleaned_text}\n"
|
f"{cleaned_text}\n"
|
||||||
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
||||||
"You must use the same language as the user."
|
"You must use the same language as the user."
|
||||||
"If you think the dialog is too short to summarize, only output a special mark: `None`"
|
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if llm_resp and llm_resp.completion_text:
|
if llm_resp and llm_resp.completion_text:
|
||||||
@@ -325,7 +304,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
|
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
|
||||||
)
|
)
|
||||||
title = llm_resp.completion_text.strip()
|
title = llm_resp.completion_text.strip()
|
||||||
if not title or "None" == title:
|
if not title or "<None>" in title:
|
||||||
return
|
return
|
||||||
await self.conv_manager.update_conversation_title(
|
await self.conv_manager.update_conversation_title(
|
||||||
event.unified_msg_origin, title=title
|
event.unified_msg_origin, title=title
|
||||||
@@ -341,330 +320,50 @@ class LLMRequestSubStage(Stage):
|
|||||||
cid=cid,
|
cid=cid,
|
||||||
title=title,
|
title=title,
|
||||||
)
|
)
|
||||||
web_chat_back_queue.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "update_title",
|
|
||||||
"cid": cid,
|
|
||||||
"data": title,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_llm_response(
|
|
||||||
self,
|
|
||||||
event: AstrMessageEvent,
|
|
||||||
req: ProviderRequest,
|
|
||||||
llm_response: LLMResponse,
|
|
||||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
|
||||||
"""处理非流式 LLM 响应。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
|
||||||
"""
|
|
||||||
if llm_response.role == "assistant":
|
|
||||||
# text completion
|
|
||||||
if llm_response.result_chain:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult(
|
|
||||||
chain=llm_response.result_chain.chain
|
|
||||||
).set_result_content_type(ResultContentType.LLM_RESULT)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult()
|
|
||||||
.message(llm_response.completion_text)
|
|
||||||
.set_result_content_type(ResultContentType.LLM_RESULT)
|
|
||||||
)
|
|
||||||
elif llm_response.role == "err":
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(
|
|
||||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif llm_response.role == "tool":
|
|
||||||
# 处理函数工具调用
|
|
||||||
async for result in self._handle_function_tools(event, req, llm_response):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
async def _handle_llm_stream_response(
|
|
||||||
self,
|
|
||||||
event: AstrMessageEvent,
|
|
||||||
req: ProviderRequest,
|
|
||||||
llm_response: LLMResponse,
|
|
||||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
|
||||||
"""处理流式 LLM 响应。
|
|
||||||
|
|
||||||
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
|
||||||
"""
|
|
||||||
if llm_response.role == "assistant":
|
|
||||||
# text completion
|
|
||||||
if llm_response.result_chain:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult(
|
|
||||||
chain=llm_response.result_chain.chain
|
|
||||||
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult()
|
|
||||||
.message(llm_response.completion_text)
|
|
||||||
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
|
||||||
)
|
|
||||||
elif llm_response.role == "err":
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(
|
|
||||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif llm_response.role == "tool":
|
|
||||||
# 处理函数工具调用
|
|
||||||
async for result in self._handle_function_tools(event, req, llm_response):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
async def _handle_function_tools(
|
|
||||||
self,
|
|
||||||
event: AstrMessageEvent,
|
|
||||||
req: ProviderRequest,
|
|
||||||
llm_response: LLMResponse,
|
|
||||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
|
||||||
"""处理函数工具调用。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
|
||||||
"""
|
|
||||||
# function calling
|
|
||||||
tool_call_result: list[ToolCallMessageSegment] = []
|
|
||||||
logger.info(
|
|
||||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
|
||||||
)
|
|
||||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
|
||||||
llm_response.tools_call_name,
|
|
||||||
llm_response.tools_call_args,
|
|
||||||
llm_response.tools_call_ids,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
func_tool = req.func_tool.get_func(func_tool_name)
|
|
||||||
if func_tool.origin == "mcp":
|
|
||||||
logger.info(
|
|
||||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
|
||||||
)
|
|
||||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
|
||||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
|
||||||
if res:
|
|
||||||
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
|
||||||
if isinstance(res.content[0], TextContent):
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content=res.content[0].text,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(res.content[0], ImageContent):
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content="返回了图片(已直接发送给用户)",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event.set_extra(
|
|
||||||
"tool_call_img_respond",
|
|
||||||
res.content[0].data,
|
|
||||||
)
|
|
||||||
elif isinstance(res.content[0], EmbeddedResource):
|
|
||||||
resource = res.content[0].resource
|
|
||||||
if isinstance(resource, TextResourceContents):
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content=resource.text,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
isinstance(resource, BlobResourceContents)
|
|
||||||
and resource.mimeType
|
|
||||||
and resource.mimeType.startswith("image/")
|
|
||||||
):
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content="返回了图片(已直接发送给用户)",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event.set_extra(
|
|
||||||
"tool_call_img_respond",
|
|
||||||
res.content[0].data,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content="返回的数据类型不受支持",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 获取处理器,过滤掉平台不兼容的处理器
|
|
||||||
platform_id = event.get_platform_id()
|
|
||||||
star_md = star_map.get(func_tool.handler_module_path)
|
|
||||||
if (
|
|
||||||
star_md
|
|
||||||
and platform_id in star_md.supported_platforms
|
|
||||||
and not star_md.supported_platforms[platform_id]
|
|
||||||
):
|
|
||||||
logger.debug(
|
|
||||||
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
|
||||||
)
|
|
||||||
# 直接跳过,不添加任何消息到tool_call_result
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
|
||||||
)
|
|
||||||
# 尝试调用工具函数
|
|
||||||
wrapper = self._call_handler(
|
|
||||||
self.ctx, event, func_tool.handler, **func_tool_args
|
|
||||||
)
|
|
||||||
async for resp in wrapper:
|
|
||||||
if resp is not None: # 有 return 返回
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content=resp,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
res = event.get_result()
|
|
||||||
if res and res.chain:
|
|
||||||
event.set_extra("tool_call_result", res)
|
|
||||||
yield # 有生成器返回
|
|
||||||
event.clear_result() # 清除上一个 handler 的结果
|
|
||||||
except BaseException as e:
|
|
||||||
logger.warning(traceback.format_exc())
|
|
||||||
tool_call_result.append(
|
|
||||||
ToolCallMessageSegment(
|
|
||||||
role="tool",
|
|
||||||
tool_call_id=func_tool_id,
|
|
||||||
content=f"error: {str(e)}",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if tool_call_result:
|
|
||||||
# 函数调用结果
|
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
|
||||||
assistant_msg_seg = AssistantMessageSegment(
|
|
||||||
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
|
||||||
)
|
|
||||||
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
|
||||||
req.tool_calls_result = ToolCallsResult(
|
|
||||||
tool_calls_info=assistant_msg_seg,
|
|
||||||
tool_calls_result=tool_call_result,
|
|
||||||
)
|
|
||||||
yield req # 再次执行 LLM 请求
|
|
||||||
else:
|
|
||||||
if llm_response.completion_text:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(llm_response.completion_text)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _save_to_history(
|
async def _save_to_history(
|
||||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse | None,
|
||||||
):
|
):
|
||||||
if not req or not req.conversation or not llm_response:
|
if (
|
||||||
|
not req
|
||||||
|
or not req.conversation
|
||||||
|
or not llm_response
|
||||||
|
or llm_response.role != "assistant"
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
if llm_response.role == "assistant":
|
# 历史上下文
|
||||||
# 文本回复
|
messages = copy.deepcopy(req.contexts)
|
||||||
contexts = req.contexts.copy()
|
# 这一轮对话请求的用户输入
|
||||||
contexts.append(await req.assemble_context())
|
messages.append(await req.assemble_context())
|
||||||
|
# 这一轮对话的 LLM 响应
|
||||||
|
if req.tool_calls_result:
|
||||||
|
if not isinstance(req.tool_calls_result, list):
|
||||||
|
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||||
|
elif isinstance(req.tool_calls_result, list):
|
||||||
|
for tcr in req.tool_calls_result:
|
||||||
|
messages.extend(tcr.to_openai_messages())
|
||||||
|
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||||
|
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||||
|
await self.conv_manager.update_conversation(
|
||||||
|
event.unified_msg_origin, req.conversation.cid, history=messages
|
||||||
|
)
|
||||||
|
|
||||||
# 记录并标记函数调用结果
|
def fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||||
if req.tool_calls_result:
|
"""验证并且修复上下文"""
|
||||||
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
fixed_messages = []
|
||||||
|
for message in messages:
|
||||||
# 添加标记
|
if message.get("role") == "tool":
|
||||||
for message in tool_calls_messages:
|
# tool block 前面必须要有 user 和 assistant block
|
||||||
message["_tool_call_history"] = True
|
if len(fixed_messages) < 2:
|
||||||
|
# 这种情况可能是上下文被截断导致的
|
||||||
processed_tool_messages = self._process_tool_message_pairs(
|
# 我们直接将之前的上下文都清空
|
||||||
tool_calls_messages, remove_tags=False
|
fixed_messages = []
|
||||||
)
|
else:
|
||||||
|
fixed_messages.append(message)
|
||||||
contexts.extend(processed_tool_messages)
|
|
||||||
|
|
||||||
contexts.append(
|
|
||||||
{"role": "assistant", "content": llm_response.completion_text}
|
|
||||||
)
|
|
||||||
contexts_to_save = list(
|
|
||||||
filter(lambda item: "_no_save" not in item, contexts)
|
|
||||||
)
|
|
||||||
await self.conv_manager.update_conversation(
|
|
||||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
|
||||||
"""处理工具调用消息,确保assistant和tool消息成对出现
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages (list): 消息列表
|
|
||||||
remove_tags (bool): 是否移除_tool_call_history标记
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
i = 0
|
|
||||||
|
|
||||||
while i < len(messages):
|
|
||||||
current_msg = messages[i]
|
|
||||||
|
|
||||||
# 普通消息直接添加
|
|
||||||
if "_tool_call_history" not in current_msg:
|
|
||||||
result.append(current_msg.copy() if remove_tags else current_msg)
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 工具调用消息成对处理
|
|
||||||
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
|
||||||
assistant_msg = current_msg.copy()
|
|
||||||
|
|
||||||
if remove_tags and "_tool_call_history" in assistant_msg:
|
|
||||||
del assistant_msg["_tool_call_history"]
|
|
||||||
|
|
||||||
related_tools = []
|
|
||||||
j = i + 1
|
|
||||||
while (
|
|
||||||
j < len(messages)
|
|
||||||
and messages[j].get("role") == "tool"
|
|
||||||
and "_tool_call_history" in messages[j]
|
|
||||||
):
|
|
||||||
tool_msg = messages[j].copy()
|
|
||||||
|
|
||||||
if remove_tags:
|
|
||||||
del tool_msg["_tool_call_history"]
|
|
||||||
|
|
||||||
related_tools.append(tool_msg)
|
|
||||||
j += 1
|
|
||||||
|
|
||||||
# 成对的时候添加到结果
|
|
||||||
if related_tools:
|
|
||||||
result.append(assistant_msg)
|
|
||||||
result.extend(related_tools)
|
|
||||||
|
|
||||||
i = j # 跳过已处理
|
|
||||||
else:
|
else:
|
||||||
# 单独的tool消息
|
fixed_messages.append(message)
|
||||||
i += 1
|
return fixed_messages
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class StarRequestSubStage(Stage):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||||
)
|
)
|
||||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
wrapper = self.ctx.call_handler(event, handler.handler, **params)
|
||||||
async for ret in wrapper:
|
async for ret in wrapper:
|
||||||
yield ret
|
yield ret
|
||||||
event.clear_result() # 清除上一个 handler 的结果
|
event.clear_result() # 清除上一个 handler 的结果
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
|
|||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.utils.path_util import path_Mapping
|
from astrbot.core.utils.path_util import path_Mapping
|
||||||
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
@@ -128,9 +129,7 @@ class RespondStage(Stage):
|
|||||||
"streaming_segmented", False
|
"streaming_segmented", False
|
||||||
)
|
)
|
||||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||||
await event._pre_send()
|
|
||||||
await event.send_streaming(result.async_stream, use_fallback)
|
await event.send_streaming(result.async_stream, use_fallback)
|
||||||
await event._post_send()
|
|
||||||
return
|
return
|
||||||
elif len(result.chain) > 0:
|
elif len(result.chain) > 0:
|
||||||
# 检查路径映射
|
# 检查路径映射
|
||||||
@@ -141,8 +140,6 @@ class RespondStage(Stage):
|
|||||||
component.file = path_Mapping(mappings, component.file)
|
component.file = path_Mapping(mappings, component.file)
|
||||||
event.get_result().chain[idx] = component
|
event.get_result().chain[idx] = component
|
||||||
|
|
||||||
await event._pre_send()
|
|
||||||
|
|
||||||
# 检查消息链是否为空
|
# 检查消息链是否为空
|
||||||
try:
|
try:
|
||||||
if await self._is_empty_message_chain(result.chain):
|
if await self._is_empty_message_chain(result.chain):
|
||||||
@@ -158,9 +155,14 @@ class RespondStage(Stage):
|
|||||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.enable_seg and (
|
if (
|
||||||
(self.only_llm_result and result.is_llm_result())
|
self.enable_seg
|
||||||
or not self.only_llm_result
|
and (
|
||||||
|
(self.only_llm_result and result.is_llm_result())
|
||||||
|
or not self.only_llm_result
|
||||||
|
)
|
||||||
|
and event.get_platform_name()
|
||||||
|
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
||||||
):
|
):
|
||||||
decorated_comps = []
|
decorated_comps = []
|
||||||
if self.reply_with_mention:
|
if self.reply_with_mention:
|
||||||
@@ -176,25 +178,26 @@ class RespondStage(Stage):
|
|||||||
result.chain.remove(comp)
|
result.chain.remove(comp)
|
||||||
break
|
break
|
||||||
|
|
||||||
for rcomp in record_comps:
|
# leverage lock to guarentee the order of message sending among different events
|
||||||
i = await self._calc_comp_interval(rcomp)
|
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||||
await asyncio.sleep(i)
|
for rcomp in record_comps:
|
||||||
try:
|
i = await self._calc_comp_interval(rcomp)
|
||||||
await event.send(MessageChain([rcomp]))
|
await asyncio.sleep(i)
|
||||||
except Exception as e:
|
try:
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
await event.send(MessageChain([rcomp]))
|
||||||
break
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
# 分段回复
|
break
|
||||||
for comp in non_record_comps:
|
# 分段回复
|
||||||
i = await self._calc_comp_interval(comp)
|
for comp in non_record_comps:
|
||||||
await asyncio.sleep(i)
|
i = await self._calc_comp_interval(comp)
|
||||||
try:
|
await asyncio.sleep(i)
|
||||||
await event.send(MessageChain([*decorated_comps, comp]))
|
try:
|
||||||
decorated_comps = [] # 清空已发送的装饰组件
|
await event.send(MessageChain([*decorated_comps, comp]))
|
||||||
except Exception as e:
|
decorated_comps = [] # 清空已发送的装饰组件
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
except Exception as e:
|
||||||
break
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
for rcomp in record_comps:
|
for rcomp in record_comps:
|
||||||
try:
|
try:
|
||||||
@@ -208,7 +211,6 @@ class RespondStage(Stage):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
|
||||||
await event._post_send()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import AsyncGenerator, Union
|
from typing import AsyncGenerator, Union
|
||||||
|
|
||||||
from astrbot.core import html_renderer, logger, file_token_service
|
from astrbot.core import file_token_service, html_renderer, logger
|
||||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||||
from astrbot.core.message.message_event_result import ResultContentType
|
from astrbot.core.message.message_event_result import ResultContentType
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
|
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||||
|
|
||||||
@@ -141,7 +142,11 @@ class ResultDecorateStage(Stage):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# 分段回复
|
# 分段回复
|
||||||
if self.enable_segmented_reply:
|
if self.enable_segmented_reply and event.get_platform_name() not in [
|
||||||
|
"qq_official",
|
||||||
|
"weixin_official_account",
|
||||||
|
"dingtalk",
|
||||||
|
]:
|
||||||
if (
|
if (
|
||||||
self.only_llm_result and result.is_llm_result()
|
self.only_llm_result and result.is_llm_result()
|
||||||
) or not self.only_llm_result:
|
) or not self.only_llm_result:
|
||||||
@@ -172,10 +177,12 @@ class ResultDecorateStage(Stage):
|
|||||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||||
event.unified_msg_origin
|
event.unified_msg_origin
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
and tts_provider
|
and tts_provider
|
||||||
|
and SessionServiceManager.should_process_tts_request(event)
|
||||||
):
|
):
|
||||||
new_chain = []
|
new_chain = []
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class PipelineScheduler:
|
|||||||
await self._process_stages(event)
|
await self._process_stages(event)
|
||||||
|
|
||||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
if event.get_platform_name() == "webchat":
|
||||||
await event.send(None)
|
await event.send(None)
|
||||||
|
|
||||||
logger.debug("pipeline 执行完毕。")
|
logger.debug("pipeline 执行完毕。")
|
||||||
|
|||||||
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from ..stage import Stage, register_stage
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from typing import AsyncGenerator, Union
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
@register_stage
|
||||||
|
class SessionStatusCheckStage(Stage):
|
||||||
|
"""检查会话是否整体启用"""
|
||||||
|
|
||||||
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
# 检查会话是否整体启用
|
||||||
|
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||||
|
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||||
|
event.stop_event()
|
||||||
@@ -1,12 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import inspect
|
from typing import List, AsyncGenerator, Union
|
||||||
import traceback
|
|
||||||
from astrbot.api import logger
|
|
||||||
from typing import List, AsyncGenerator, Union, Awaitable
|
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from .context import PipelineContext
|
from .context import PipelineContext
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
|
||||||
|
|
||||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||||
|
|
||||||
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
|
|||||||
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _call_handler(
|
|
||||||
self,
|
|
||||||
ctx: PipelineContext,
|
|
||||||
event: AstrMessageEvent,
|
|
||||||
handler: Awaitable,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> AsyncGenerator[None, None]:
|
|
||||||
"""执行事件处理函数并处理其返回结果
|
|
||||||
|
|
||||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
|
||||||
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
|
||||||
2. 协程: 执行一次并处理返回值
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ctx (PipelineContext): 消息管道上下文对象
|
|
||||||
event (AstrMessageEvent): 待处理的事件对象
|
|
||||||
handler (Awaitable): 事件处理函数
|
|
||||||
*args: 传递给handler的位置参数
|
|
||||||
**kwargs: 传递给handler的关键字参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
|
||||||
"""
|
|
||||||
ready_to_call = None # 一个协程或者异步生成器(async def)
|
|
||||||
|
|
||||||
trace_ = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
ready_to_call = handler(event, *args, **kwargs)
|
|
||||||
except TypeError as _:
|
|
||||||
# 向下兼容
|
|
||||||
trace_ = traceback.format_exc()
|
|
||||||
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
|
||||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
|
||||||
|
|
||||||
if isinstance(ready_to_call, AsyncGenerator):
|
|
||||||
# 如果是一个异步生成器, 进入洋葱模型
|
|
||||||
_has_yielded = False # 是否返回过值
|
|
||||||
try:
|
|
||||||
async for ret in ready_to_call:
|
|
||||||
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
|
||||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
|
||||||
_has_yielded = True
|
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
|
||||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
|
||||||
event.set_result(ret)
|
|
||||||
yield # 传递控制权给上一层的process函数
|
|
||||||
else:
|
|
||||||
# 如果返回值是 None, 则不设置结果并继续
|
|
||||||
# 继续执行后续阶段
|
|
||||||
yield ret # 传递控制权给上一层的process函数
|
|
||||||
if not _has_yielded:
|
|
||||||
# 如果这个异步生成器没有执行到yield分支
|
|
||||||
yield
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Previous Error: {trace_}")
|
|
||||||
raise e
|
|
||||||
elif inspect.iscoroutine(ready_to_call):
|
|
||||||
# 如果只是一个协程, 直接执行
|
|
||||||
ret = await ready_to_call
|
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
|
||||||
event.set_result(ret)
|
|
||||||
yield # 传递控制权给上一层的process函数
|
|
||||||
else:
|
|
||||||
yield ret # 传递控制权给上一层的process函数
|
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
from ..stage import Stage, register_stage
|
from typing import AsyncGenerator, Union
|
||||||
from ..context import PipelineContext
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from typing import Union, AsyncGenerator
|
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
|
||||||
from astrbot.core.message.components import At, AtAll, Reply
|
from astrbot.core.message.components import At, AtAll, Reply
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||||
|
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||||
|
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from ..stage import Stage, register_stage
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
@@ -135,7 +138,6 @@ class WakingCheckStage(Stage):
|
|||||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await event._post_send()
|
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
passed = False
|
passed = False
|
||||||
break
|
break
|
||||||
@@ -150,7 +152,6 @@ class WakingCheckStage(Stage):
|
|||||||
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await event._post_send()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||||
)
|
)
|
||||||
@@ -166,7 +167,12 @@ class WakingCheckStage(Stage):
|
|||||||
"parsed_params"
|
"parsed_params"
|
||||||
)
|
)
|
||||||
|
|
||||||
event.clear_extra()
|
event._extras.pop("parsed_params", None)
|
||||||
|
|
||||||
|
# 根据会话配置过滤插件处理器
|
||||||
|
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||||
|
event, activated_handlers
|
||||||
|
)
|
||||||
|
|
||||||
event.set_extra("activated_handlers", activated_handlers)
|
event.set_extra("activated_handlers", activated_handlers)
|
||||||
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
):
|
):
|
||||||
"""发送流式消息到消息平台,使用异步生成器。
|
"""发送流式消息到消息平台,使用异步生成器。
|
||||||
目前仅支持: telegram,qq official 私聊。
|
目前仅支持: telegram,qq official 私聊。
|
||||||
Fallback仅支持 aiocqhttp, gewechat。
|
Fallback仅支持 aiocqhttp。
|
||||||
"""
|
"""
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||||
@@ -235,10 +235,10 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
self._has_send_oper = True
|
self._has_send_oper = True
|
||||||
|
|
||||||
async def _pre_send(self):
|
async def _pre_send(self):
|
||||||
"""调度器会在执行 send() 前调用该方法"""
|
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
|
||||||
|
|
||||||
async def _post_send(self):
|
async def _post_send(self):
|
||||||
"""调度器会在执行 send() 后调用该方法"""
|
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
|
||||||
|
|
||||||
def set_result(self, result: Union[MessageEventResult, str]):
|
def set_result(self, result: Union[MessageEventResult, str]):
|
||||||
"""设置消息事件的结果。
|
"""设置消息事件的结果。
|
||||||
@@ -419,7 +419,6 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
|
|
||||||
适配情况:
|
适配情况:
|
||||||
|
|
||||||
- gewechat
|
|
||||||
- aiocqhttp(OneBotv11)
|
- aiocqhttp(OneBotv11)
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|||||||
@@ -58,10 +58,6 @@ class PlatformManager:
|
|||||||
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||||
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "gewechat":
|
|
||||||
from .sources.gewechat.gewechat_platform_adapter import (
|
|
||||||
GewechatPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
case "wechatpadpro":
|
case "wechatpadpro":
|
||||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||||
WeChatPadProAdapter, # noqa: F401
|
WeChatPadProAdapter, # noqa: F401
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from typing import AsyncGenerator, Dict, List
|
from typing import AsyncGenerator, Dict, List
|
||||||
from aiocqhttp import CQHttp
|
from aiocqhttp import CQHttp, Event
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import (
|
from astrbot.api.message_components import (
|
||||||
Image,
|
Image,
|
||||||
@@ -58,50 +58,85 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
ret.append(d)
|
ret.append(d)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
@classmethod
|
||||||
|
async def _dispatch_send(
|
||||||
|
cls,
|
||||||
|
bot: CQHttp,
|
||||||
|
event: Event | None,
|
||||||
|
is_group: bool,
|
||||||
|
session_id: str,
|
||||||
|
messages: list[dict],
|
||||||
|
):
|
||||||
|
if event:
|
||||||
|
await bot.send(event=event, message=messages)
|
||||||
|
elif is_group:
|
||||||
|
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||||
|
else:
|
||||||
|
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def send_message(
|
||||||
|
cls,
|
||||||
|
bot: CQHttp,
|
||||||
|
message_chain: MessageChain,
|
||||||
|
event: Event | None = None,
|
||||||
|
is_group: bool = False,
|
||||||
|
session_id: str = None,
|
||||||
|
):
|
||||||
|
"""发送消息"""
|
||||||
|
|
||||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||||
send_one_by_one = any(
|
send_one_by_one = any(
|
||||||
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
|
||||||
)
|
)
|
||||||
if send_one_by_one:
|
if not send_one_by_one:
|
||||||
for seg in message.chain:
|
ret = await cls._parse_onebot_json(message_chain)
|
||||||
if isinstance(seg, (Node, Nodes)):
|
|
||||||
# 合并转发消息
|
|
||||||
|
|
||||||
if isinstance(seg, Node):
|
|
||||||
nodes = Nodes([seg])
|
|
||||||
seg = nodes
|
|
||||||
|
|
||||||
payload = await seg.to_dict()
|
|
||||||
|
|
||||||
if self.get_group_id():
|
|
||||||
payload["group_id"] = self.get_group_id()
|
|
||||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
|
||||||
else:
|
|
||||||
payload["user_id"] = self.get_sender_id()
|
|
||||||
await self.bot.call_action(
|
|
||||||
"send_private_forward_msg", **payload
|
|
||||||
)
|
|
||||||
elif isinstance(seg, File):
|
|
||||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
|
||||||
await self.bot.send(
|
|
||||||
self.message_obj.raw_message,
|
|
||||||
[d],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self.bot.send(
|
|
||||||
self.message_obj.raw_message,
|
|
||||||
await AiocqhttpMessageEvent._parse_onebot_json(
|
|
||||||
MessageChain([seg])
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
else:
|
|
||||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
|
||||||
if not ret:
|
if not ret:
|
||||||
return
|
return
|
||||||
await self.bot.send(self.message_obj.raw_message, ret)
|
await cls._dispatch_send(bot, event, is_group, session_id, ret)
|
||||||
|
return
|
||||||
|
for seg in message_chain.chain:
|
||||||
|
if isinstance(seg, (Node, Nodes)):
|
||||||
|
# 合并转发消息
|
||||||
|
if isinstance(seg, Node):
|
||||||
|
nodes = Nodes([seg])
|
||||||
|
seg = nodes
|
||||||
|
|
||||||
|
payload = await seg.to_dict()
|
||||||
|
|
||||||
|
if is_group:
|
||||||
|
payload["group_id"] = session_id
|
||||||
|
await bot.call_action("send_group_forward_msg", **payload)
|
||||||
|
else:
|
||||||
|
payload["user_id"] = session_id
|
||||||
|
await bot.call_action("send_private_forward_msg", **payload)
|
||||||
|
elif isinstance(seg, File):
|
||||||
|
d = await cls._from_segment_to_dict(seg)
|
||||||
|
await cls._dispatch_send(bot, event, is_group, session_id, [d])
|
||||||
|
else:
|
||||||
|
messages = await cls._parse_onebot_json(MessageChain([seg]))
|
||||||
|
if not messages:
|
||||||
|
continue
|
||||||
|
await cls._dispatch_send(bot, event, is_group, session_id, messages)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
"""发送消息"""
|
||||||
|
event = self.message_obj.raw_message
|
||||||
|
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
|
||||||
|
is_group = False
|
||||||
|
if self.get_group_id():
|
||||||
|
is_group = True
|
||||||
|
session_id = self.get_group_id()
|
||||||
|
else:
|
||||||
|
session_id = self.get_sender_id()
|
||||||
|
await self.send_message(
|
||||||
|
bot=self.bot,
|
||||||
|
message_chain=message,
|
||||||
|
event=event,
|
||||||
|
is_group=is_group,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
async def send_streaming(
|
async def send_streaming(
|
||||||
|
|||||||
@@ -83,19 +83,18 @@ class AiocqhttpAdapter(Platform):
|
|||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
):
|
):
|
||||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
is_group = session.message_type == MessageType.GROUP_MESSAGE
|
||||||
match session.message_type.value:
|
if is_group:
|
||||||
case MessageType.GROUP_MESSAGE.value:
|
session_id = session.session_id.split("_")[-1]
|
||||||
if "_" in session.session_id:
|
else:
|
||||||
# 独立会话
|
session_id = session.session_id
|
||||||
_, group_id = session.session_id.split("_")
|
await AiocqhttpMessageEvent.send_message(
|
||||||
await self.bot.send_group_msg(group_id=group_id, message=ret)
|
bot=self.bot,
|
||||||
else:
|
message_chain=message_chain,
|
||||||
await self.bot.send_group_msg(
|
event=None, # 这里不需要 event,因为是通过 session 发送的
|
||||||
group_id=session.session_id, message=ret
|
is_group=is_group,
|
||||||
)
|
session_id=session_id,
|
||||||
case MessageType.FRIEND_MESSAGE.value:
|
)
|
||||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
|
||||||
await super().send_by_session(session, message_chain)
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||||
@@ -168,9 +167,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
|
|
||||||
if "sub_type" in event:
|
if "sub_type" in event:
|
||||||
if event["sub_type"] == "poke" and "target_id" in event:
|
if event["sub_type"] == "poke" and "target_id" in event:
|
||||||
abm.message.append(
|
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
||||||
Poke(qq=str(event["target_id"]), type="poke")
|
|
||||||
) # noqa: F405
|
|
||||||
|
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
@@ -273,8 +270,16 @@ class AiocqhttpAdapter(Platform):
|
|||||||
action="get_msg",
|
action="get_msg",
|
||||||
message_id=int(m["data"]["id"]),
|
message_id=int(m["data"]["id"]),
|
||||||
)
|
)
|
||||||
|
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
|
||||||
|
reply_event_data["post_type"] = "message"
|
||||||
|
new_event = Event.from_payload(reply_event_data)
|
||||||
|
if not new_event:
|
||||||
|
logger.error(
|
||||||
|
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
abm_reply = await self._convert_handle_message_event(
|
abm_reply = await self._convert_handle_message_event(
|
||||||
Event.from_payload(reply_event_data), get_reply=False
|
new_event, get_reply=False
|
||||||
)
|
)
|
||||||
|
|
||||||
reply_seg = Reply(
|
reply_seg = Reply(
|
||||||
@@ -307,7 +312,9 @@ class AiocqhttpAdapter(Platform):
|
|||||||
user_id=int(m["data"]["qq"]),
|
user_id=int(m["data"]["qq"]),
|
||||||
)
|
)
|
||||||
if at_info:
|
if at_info:
|
||||||
nickname = at_info.get("nick", "")
|
nickname = at_info.get("nick", "") or at_info.get(
|
||||||
|
"nickname", ""
|
||||||
|
)
|
||||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||||
|
|
||||||
abm.message.append(
|
abm.message.append(
|
||||||
@@ -322,7 +329,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
first_at_self_processed = True
|
first_at_self_processed = True
|
||||||
else:
|
else:
|
||||||
# 非第一个@机器人或@其他用户,添加到message_str
|
# 非第一个@机器人或@其他用户,添加到message_str
|
||||||
message_str += f" @{nickname} "
|
message_str += f" @{nickname}({m['data']['qq']}) "
|
||||||
else:
|
else:
|
||||||
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||||
except ActionFailed as e:
|
except ActionFailed as e:
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|||||||
logger.error(f"钉钉图片处理失败: {e}")
|
logger.error(f"钉钉图片处理失败: {e}")
|
||||||
logger.warning(f"跳过图片发送: {image_path}")
|
logger.warning(f"跳过图片发送: {image_path}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
await self.send_with_client(self.client, message)
|
await self.send_with_client(self.client, message)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
|
|||||||
await self.on_ready_once_callback()
|
await self.on_ready_once_callback()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
def _create_message_data(self, message: discord.Message) -> dict:
|
def _create_message_data(self, message: discord.Message) -> dict:
|
||||||
"""从 discord.Message 创建数据字典"""
|
"""从 discord.Message 创建数据字典"""
|
||||||
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
|
|||||||
message_data = self._create_message_data(message)
|
message_data = self._create_message_data(message)
|
||||||
await self.on_message_received(message_data)
|
await self.on_message_received(message_data)
|
||||||
|
|
||||||
|
|
||||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||||
"""从交互中提取内容"""
|
"""从交互中提取内容"""
|
||||||
interaction_type = interaction.type
|
interaction_type = interaction.type
|
||||||
|
|||||||
@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
|
|||||||
self.url = url
|
self.url = url
|
||||||
self.disabled = disabled
|
self.disabled = disabled
|
||||||
|
|
||||||
|
|
||||||
class DiscordReference(BaseMessageComponent):
|
class DiscordReference(BaseMessageComponent):
|
||||||
"""Discord引用组件"""
|
"""Discord引用组件"""
|
||||||
|
|
||||||
type: str = "discord_reference"
|
type: str = "discord_reference"
|
||||||
|
|
||||||
def __init__(self, message_id: str, channel_id: str):
|
def __init__(self, message_id: str, channel_id: str):
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.channel_id = channel_id
|
self.channel_id = channel_id
|
||||||
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
|
|||||||
self.components = components or []
|
self.components = components or []
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
|
|
||||||
def to_discord_view(self) -> discord.ui.View:
|
def to_discord_view(self) -> discord.ui.View:
|
||||||
"""转换为Discord View对象"""
|
"""转换为Discord View对象"""
|
||||||
view = discord.ui.View(timeout=self.timeout)
|
view = discord.ui.View(timeout=self.timeout)
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||||
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
|
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
|
||||||
self.activity_name = self.config.get("discord_activity_name", None)
|
self.activity_name = self.config.get("discord_activity_name", None)
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
self._polling_task = None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
@@ -137,7 +139,8 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
self.client.on_ready_once_callback = callback
|
self.client.on_ready_once_callback = callback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.client.start_polling()
|
self._polling_task = asyncio.create_task(self.client.start_polling())
|
||||||
|
await self.shutdown_event.wait()
|
||||||
except discord.errors.LoginFailure:
|
except discord.errors.LoginFailure:
|
||||||
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
|
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
|
||||||
except discord.errors.ConnectionClosed:
|
except discord.errors.ConnectionClosed:
|
||||||
@@ -162,42 +165,47 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||||
"""将普通消息转换为 AstrBotMessage"""
|
"""将普通消息转换为 AstrBotMessage"""
|
||||||
message: discord.Message = data["message"]
|
message: discord.Message = data["message"]
|
||||||
is_mentioned = data.get("is_mentioned", False)
|
|
||||||
|
|
||||||
content = message.content
|
content = message.content
|
||||||
|
|
||||||
# 如果机器人被@,移除@部分
|
# 如果机器人被@,移除@部分
|
||||||
if (
|
# 剥离 User Mention (<@id>, <@!id>)
|
||||||
is_mentioned
|
if self.client and self.client.user:
|
||||||
and self.client
|
|
||||||
and self.client.user
|
|
||||||
and self.client.user in message.mentions
|
|
||||||
):
|
|
||||||
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
|
|
||||||
mention_str = f"<@{self.client.user.id}>"
|
mention_str = f"<@{self.client.user.id}>"
|
||||||
mention_str_nickname = (
|
mention_str_nickname = f"<@!{self.client.user.id}>"
|
||||||
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
|
|
||||||
)
|
|
||||||
|
|
||||||
if content.startswith(mention_str):
|
if content.startswith(mention_str):
|
||||||
content = content[len(mention_str) :].lstrip()
|
content = content[len(mention_str) :].lstrip()
|
||||||
elif content.startswith(mention_str_nickname):
|
elif content.startswith(mention_str_nickname):
|
||||||
content = content[len(mention_str_nickname) :].lstrip()
|
content = content[len(mention_str_nickname) :].lstrip()
|
||||||
|
|
||||||
abm = AstrBotMessage()
|
# 剥离 Role Mention(bot 拥有的任一角色被提及,<@&role_id>)
|
||||||
|
if (
|
||||||
|
hasattr(message, "role_mentions")
|
||||||
|
and hasattr(message, "guild")
|
||||||
|
and message.guild
|
||||||
|
):
|
||||||
|
bot_member = (
|
||||||
|
message.guild.get_member(self.client.user.id)
|
||||||
|
if self.client and self.client.user
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if bot_member and hasattr(bot_member, "roles"):
|
||||||
|
for role in bot_member.roles:
|
||||||
|
role_mention_str = f"<@&{role.id}>"
|
||||||
|
if content.startswith(role_mention_str):
|
||||||
|
content = content[len(role_mention_str) :].lstrip()
|
||||||
|
break # 只剥离第一个匹配的角色 mention
|
||||||
|
|
||||||
|
abm = AstrBotMessage()
|
||||||
abm.type = self._get_message_type(message.channel)
|
abm.type = self._get_message_type(message.channel)
|
||||||
abm.group_id = self._get_channel_id(message.channel)
|
abm.group_id = self._get_channel_id(message.channel)
|
||||||
|
|
||||||
abm.message_str = content
|
abm.message_str = content
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
user_id=str(message.author.id), nickname=message.author.display_name
|
user_id=str(message.author.id), nickname=message.author.display_name
|
||||||
)
|
)
|
||||||
|
|
||||||
message_chain = []
|
message_chain = []
|
||||||
if abm.message_str:
|
if abm.message_str:
|
||||||
message_chain.append(Plain(text=abm.message_str))
|
message_chain.append(Plain(text=abm.message_str))
|
||||||
|
|
||||||
if message.attachments:
|
if message.attachments:
|
||||||
for attachment in message.attachments:
|
for attachment in message.attachments:
|
||||||
if attachment.content_type and attachment.content_type.startswith(
|
if attachment.content_type and attachment.content_type.startswith(
|
||||||
@@ -210,7 +218,6 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
message_chain.append(
|
message_chain.append(
|
||||||
File(name=attachment.filename, url=attachment.url)
|
File(name=attachment.filename, url=attachment.url)
|
||||||
)
|
)
|
||||||
|
|
||||||
abm.message = message_chain
|
abm.message = message_chain
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
abm.self_id = self.client_self_id
|
abm.self_id = self.client_self_id
|
||||||
@@ -237,13 +244,35 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
# 检查是否为斜杠指令
|
# 检查是否为斜杠指令
|
||||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||||
|
|
||||||
# 检查是否被@
|
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||||
is_mention = (
|
is_mention = False
|
||||||
|
# User Mention
|
||||||
|
if (
|
||||||
self.client
|
self.client
|
||||||
and self.client.user
|
and self.client.user
|
||||||
and hasattr(message.raw_message, "mentions")
|
and hasattr(message.raw_message, "mentions")
|
||||||
and self.client.user in message.raw_message.mentions
|
):
|
||||||
)
|
if self.client.user in message.raw_message.mentions:
|
||||||
|
is_mention = True
|
||||||
|
# Role Mention(Bot 拥有的角色被提及)
|
||||||
|
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||||
|
bot_member = None
|
||||||
|
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||||
|
try:
|
||||||
|
bot_member = message.raw_message.guild.get_member(
|
||||||
|
self.client.user.id
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
bot_member = None
|
||||||
|
if bot_member and hasattr(bot_member, "roles"):
|
||||||
|
bot_roles = set(bot_member.roles)
|
||||||
|
mentioned_roles = set(message.raw_message.role_mentions)
|
||||||
|
if (
|
||||||
|
bot_roles
|
||||||
|
and mentioned_roles
|
||||||
|
and bot_roles.intersection(mentioned_roles)
|
||||||
|
):
|
||||||
|
is_mention = True
|
||||||
|
|
||||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||||
if is_slash_command or is_mention:
|
if is_slash_command or is_mention:
|
||||||
@@ -255,23 +284,37 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
@override
|
@override
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
"""终止适配器"""
|
"""终止适配器"""
|
||||||
logger.info("[Discord] 正在终止适配器...")
|
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
|
||||||
|
self.shutdown_event.set()
|
||||||
|
# 优先 cancel polling_task
|
||||||
|
if self._polling_task:
|
||||||
|
self._polling_task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._polling_task, timeout=10)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("[Discord] polling_task 已取消。")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Discord] polling_task 取消异常: {e}")
|
||||||
|
logger.info("[Discord] 正在清理已注册的斜杠指令... (step 2)")
|
||||||
# 清理指令
|
# 清理指令
|
||||||
if self.enable_command_register and self.client:
|
if self.enable_command_register and self.client:
|
||||||
logger.info("[Discord] 正在清理已注册的斜杠指令...")
|
|
||||||
try:
|
try:
|
||||||
# 传入空的列表来清除所有全局指令
|
await asyncio.wait_for(
|
||||||
# 如果指定了 guild_id,则只清除该服务器的指令
|
self.client.sync_commands(
|
||||||
await self.client.sync_commands(
|
commands=[],
|
||||||
commands=[], guild_ids=[self.guild_id] if self.guild_id else None
|
guild_ids=[self.guild_id] if self.guild_id else None,
|
||||||
|
),
|
||||||
|
timeout=10,
|
||||||
)
|
)
|
||||||
logger.info("[Discord] 指令清理完成。")
|
logger.info("[Discord] 指令清理完成。")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
|
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
|
||||||
|
logger.info("[Discord] 正在关闭 Discord 客户端... (step 3)")
|
||||||
if self.client and hasattr(self.client, "close"):
|
if self.client and hasattr(self.client, "close"):
|
||||||
await self.client.close()
|
try:
|
||||||
|
await asyncio.wait_for(self.client.close(), timeout=10)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Discord] 客户端关闭异常: {e}")
|
||||||
logger.info("[Discord] 适配器已终止。")
|
logger.info("[Discord] 适配器已终止。")
|
||||||
|
|
||||||
def register_handler(self, handler_info):
|
def register_handler(self, handler_info):
|
||||||
|
|||||||
@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 解析消息链为 Discord 所需的对象
|
# 解析消息链为 Discord 所需的对象
|
||||||
try:
|
try:
|
||||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
(
|
||||||
|
content,
|
||||||
|
files,
|
||||||
|
view,
|
||||||
|
embeds,
|
||||||
|
reference_message_id,
|
||||||
|
) = await self._parse_to_discord(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||||
return
|
return
|
||||||
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
if await asyncio.to_thread(path.exists):
|
if await asyncio.to_thread(path.exists):
|
||||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||||
files.append(
|
files.append(
|
||||||
discord.File(BytesIO(file_bytes),
|
discord.File(BytesIO(file_bytes), filename=i.name)
|
||||||
filename=i.name)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -1,812 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
import threading
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import anyio
|
|
||||||
import quart
|
|
||||||
|
|
||||||
from astrbot.api import logger, sp
|
|
||||||
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
|
||||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
|
||||||
from .downloader import GeweDownloader
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .xml_data_parser import GeweDataParser
|
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
|
||||||
logger.warning(
|
|
||||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleGewechatClient:
|
|
||||||
"""针对 Gewechat 的简单实现。
|
|
||||||
|
|
||||||
@author: Soulter
|
|
||||||
@website: https://github.com/Soulter
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
base_url: str,
|
|
||||||
nickname: str,
|
|
||||||
host: str,
|
|
||||||
port: int,
|
|
||||||
event_queue: asyncio.Queue,
|
|
||||||
):
|
|
||||||
self.base_url = base_url
|
|
||||||
if self.base_url.endswith("/"):
|
|
||||||
self.base_url = self.base_url[:-1]
|
|
||||||
|
|
||||||
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
|
|
||||||
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
|
|
||||||
|
|
||||||
self.base_url += "/v2/api"
|
|
||||||
|
|
||||||
logger.info(f"Gewechat API: {self.base_url}")
|
|
||||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
|
||||||
|
|
||||||
if isinstance(port, str):
|
|
||||||
port = int(port)
|
|
||||||
|
|
||||||
self.token = None
|
|
||||||
self.headers = {}
|
|
||||||
self.nickname = nickname
|
|
||||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
|
||||||
|
|
||||||
self.server = quart.Quart(__name__)
|
|
||||||
self.server.add_url_rule(
|
|
||||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
|
||||||
)
|
|
||||||
self.server.add_url_rule(
|
|
||||||
"/astrbot-gewechat/file/<file_token>",
|
|
||||||
view_func=self._handle_file,
|
|
||||||
methods=["GET"],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
|
||||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
|
||||||
|
|
||||||
self.event_queue = event_queue
|
|
||||||
|
|
||||||
self.multimedia_downloader = None
|
|
||||||
|
|
||||||
self.userrealnames = {}
|
|
||||||
|
|
||||||
self.shutdown_event = asyncio.Event()
|
|
||||||
|
|
||||||
self.staged_files = {}
|
|
||||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
|
||||||
|
|
||||||
self.lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def get_token_id(self):
|
|
||||||
"""获取 Gewechat Token。"""
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
self.token = json_blob["data"]
|
|
||||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
|
||||||
self.headers = {"X-GEWE-TOKEN": self.token}
|
|
||||||
|
|
||||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
|
||||||
if "TypeName" in data:
|
|
||||||
type_name = data["TypeName"]
|
|
||||||
elif "type_name" in data:
|
|
||||||
type_name = data["type_name"]
|
|
||||||
else:
|
|
||||||
raise Exception("无法识别的消息类型")
|
|
||||||
|
|
||||||
# 以下没有业务处理,只是避免控制台打印太多的日志
|
|
||||||
if type_name == "ModContacts":
|
|
||||||
logger.info("gewechat下发:ModContacts消息通知。")
|
|
||||||
return
|
|
||||||
if type_name == "DelContacts":
|
|
||||||
logger.info("gewechat下发:DelContacts消息通知。")
|
|
||||||
return
|
|
||||||
|
|
||||||
if type_name == "Offline":
|
|
||||||
logger.critical("收到 gewechat 下线通知。")
|
|
||||||
return
|
|
||||||
|
|
||||||
d = None
|
|
||||||
if "Data" in data:
|
|
||||||
d = data["Data"]
|
|
||||||
elif "data" in data:
|
|
||||||
d = data["data"]
|
|
||||||
|
|
||||||
if not d:
|
|
||||||
logger.warning(f"消息不含 data 字段: {data}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if "CreateTime" in d:
|
|
||||||
# 得到系统 UTF+8 的 ts
|
|
||||||
tz_offset = datetime.timedelta(hours=8)
|
|
||||||
tz = datetime.timezone(tz_offset)
|
|
||||||
ts = datetime.datetime.now(tz).timestamp()
|
|
||||||
create_time = d["CreateTime"]
|
|
||||||
if create_time < ts - 30:
|
|
||||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
|
||||||
return
|
|
||||||
|
|
||||||
abm = AstrBotMessage()
|
|
||||||
|
|
||||||
from_user_name = d["FromUserName"]["string"] # 消息来源
|
|
||||||
d["to_wxid"] = from_user_name # 用于发信息
|
|
||||||
|
|
||||||
abm.message_id = str(d.get("MsgId"))
|
|
||||||
abm.session_id = from_user_name
|
|
||||||
abm.self_id = data["Wxid"] # 机器人的 wxid
|
|
||||||
|
|
||||||
user_id = "" # 发送人 wxid
|
|
||||||
content = d["Content"]["string"] # 消息内容
|
|
||||||
|
|
||||||
at_me = False
|
|
||||||
at_wxids = []
|
|
||||||
if "@chatroom" in from_user_name:
|
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
|
||||||
_t = content.split(":\n")
|
|
||||||
user_id = _t[0]
|
|
||||||
content = _t[1]
|
|
||||||
# at
|
|
||||||
msg_source = d["MsgSource"]
|
|
||||||
if "\u2005" in content:
|
|
||||||
# at
|
|
||||||
# content = content.split('\u2005')[1]
|
|
||||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
|
||||||
at_wxids = re.findall(
|
|
||||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
|
||||||
msg_source,
|
|
||||||
)
|
|
||||||
|
|
||||||
abm.group_id = from_user_name
|
|
||||||
|
|
||||||
if (
|
|
||||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
|
||||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
|
||||||
):
|
|
||||||
at_me = True
|
|
||||||
if "在群聊中@了你" in d.get("PushContent", ""):
|
|
||||||
at_me = True
|
|
||||||
else:
|
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
|
||||||
user_id = from_user_name
|
|
||||||
|
|
||||||
# 检查消息是否由自己发送,若是则忽略
|
|
||||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
|
||||||
# if user_id == abm.self_id:
|
|
||||||
# logger.info("忽略自己发送的消息")
|
|
||||||
# return None
|
|
||||||
|
|
||||||
abm.message = []
|
|
||||||
|
|
||||||
# 解析用户真实名字
|
|
||||||
user_real_name = "unknown"
|
|
||||||
if abm.group_id:
|
|
||||||
if (
|
|
||||||
abm.group_id not in self.userrealnames
|
|
||||||
or user_id not in self.userrealnames[abm.group_id]
|
|
||||||
):
|
|
||||||
# 获取群成员列表,并且缓存
|
|
||||||
if abm.group_id not in self.userrealnames:
|
|
||||||
self.userrealnames[abm.group_id] = {}
|
|
||||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
|
||||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
|
||||||
if member_list and "memberList" in member_list:
|
|
||||||
for member in member_list["memberList"]:
|
|
||||||
self.userrealnames[abm.group_id][member["wxid"]] = member[
|
|
||||||
"nickName"
|
|
||||||
]
|
|
||||||
if user_id in self.userrealnames[abm.group_id]:
|
|
||||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
|
||||||
else:
|
|
||||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
|
||||||
user_real_name = info["nickName"]
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
|
||||||
user_real_name = user_id
|
|
||||||
|
|
||||||
if at_me:
|
|
||||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
|
||||||
for wxid in at_wxids:
|
|
||||||
# 群聊里 At 其他人的列表
|
|
||||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
|
||||||
abm.message.append(At(qq=wxid, name=_username))
|
|
||||||
|
|
||||||
abm.sender = MessageMember(user_id, user_real_name)
|
|
||||||
abm.raw_message = d
|
|
||||||
abm.message_str = ""
|
|
||||||
|
|
||||||
if user_id == "weixin":
|
|
||||||
# 忽略微信团队消息
|
|
||||||
return
|
|
||||||
|
|
||||||
# 不同消息类型
|
|
||||||
match d["MsgType"]:
|
|
||||||
case 1:
|
|
||||||
# 文本消息
|
|
||||||
abm.message.append(Plain(content))
|
|
||||||
abm.message_str = content
|
|
||||||
case 3:
|
|
||||||
# 图片消息
|
|
||||||
file_url = await self.multimedia_downloader.download_image(
|
|
||||||
self.appid, content
|
|
||||||
)
|
|
||||||
logger.debug(f"下载图片: {file_url}")
|
|
||||||
file_path = await download_image_by_url(file_url)
|
|
||||||
abm.message.append(Image(file=file_path, url=file_path))
|
|
||||||
|
|
||||||
case 34:
|
|
||||||
# 语音消息
|
|
||||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
|
||||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
file_path = os.path.join(
|
|
||||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
|
||||||
)
|
|
||||||
|
|
||||||
async with await anyio.open_file(file_path, "wb") as f:
|
|
||||||
await f.write(voice_data)
|
|
||||||
abm.message.append(Record(file=file_path, url=file_path))
|
|
||||||
|
|
||||||
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
|
||||||
case 37: # 好友申请
|
|
||||||
logger.info("消息类型(37):好友申请")
|
|
||||||
case 42: # 名片
|
|
||||||
logger.info("消息类型(42):名片")
|
|
||||||
case 43: # 视频
|
|
||||||
video = Video(file="", cover=content)
|
|
||||||
abm.message.append(video)
|
|
||||||
case 47: # emoji
|
|
||||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
|
||||||
emoji = data_parser.parse_emoji()
|
|
||||||
abm.message.append(emoji)
|
|
||||||
case 48: # 地理位置
|
|
||||||
logger.info("消息类型(48):地理位置")
|
|
||||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
|
||||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
|
||||||
segments = data_parser.parse_mutil_49()
|
|
||||||
if segments:
|
|
||||||
abm.message.extend(segments)
|
|
||||||
for seg in segments:
|
|
||||||
if isinstance(seg, Plain):
|
|
||||||
abm.message_str += seg.text
|
|
||||||
case 51: # 帐号消息同步?
|
|
||||||
logger.info("消息类型(51):帐号消息同步?")
|
|
||||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
|
||||||
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
|
||||||
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
|
||||||
logger.info(
|
|
||||||
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
|
||||||
)
|
|
||||||
|
|
||||||
case _:
|
|
||||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
|
||||||
abm.raw_message = d
|
|
||||||
|
|
||||||
logger.debug(f"abm: {abm}")
|
|
||||||
return abm
|
|
||||||
|
|
||||||
async def _callback(self):
|
|
||||||
data = await quart.request.json
|
|
||||||
logger.debug(f"收到 gewechat 回调: {data}")
|
|
||||||
|
|
||||||
if data.get("testMsg", None):
|
|
||||||
return quart.jsonify({"r": "AstrBot ACK"})
|
|
||||||
|
|
||||||
abm = None
|
|
||||||
try:
|
|
||||||
abm = await self._convert(data)
|
|
||||||
except BaseException as e:
|
|
||||||
logger.warning(
|
|
||||||
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。"
|
|
||||||
)
|
|
||||||
|
|
||||||
if abm:
|
|
||||||
coro = getattr(self, "on_event_received")
|
|
||||||
if coro:
|
|
||||||
await coro(abm)
|
|
||||||
|
|
||||||
return quart.jsonify({"r": "AstrBot ACK"})
|
|
||||||
|
|
||||||
async def _register_file(self, file_path: str) -> str:
|
|
||||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path (str): 文件路径。
|
|
||||||
Returns:
|
|
||||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
|
||||||
"""
|
|
||||||
async with self.lock:
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise Exception(f"文件不存在: {file_path}")
|
|
||||||
|
|
||||||
file_token = str(uuid.uuid4())
|
|
||||||
self.staged_files[file_token] = file_path
|
|
||||||
return file_token
|
|
||||||
|
|
||||||
async def _handle_file(self, file_token):
|
|
||||||
async with self.lock:
|
|
||||||
if file_token not in self.staged_files:
|
|
||||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
|
||||||
return quart.abort(404)
|
|
||||||
if not os.path.exists(self.staged_files[file_token]):
|
|
||||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
|
||||||
return quart.abort(404)
|
|
||||||
file_path = self.staged_files[file_token]
|
|
||||||
self.staged_files.pop(file_token, None)
|
|
||||||
return await quart.send_file(file_path)
|
|
||||||
|
|
||||||
async def _set_callback_url(self):
|
|
||||||
logger.info("设置回调,请等待...")
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/tools/setCallback",
|
|
||||||
headers=self.headers,
|
|
||||||
json={"token": self.token, "callbackUrl": self.callback_url},
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.info(f"设置回调结果: {json_blob}")
|
|
||||||
if json_blob["ret"] != 200:
|
|
||||||
raise Exception(f"设置回调失败: {json_blob}")
|
|
||||||
logger.info(
|
|
||||||
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def start_polling(self):
|
|
||||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
|
||||||
await self.server.run_task(
|
|
||||||
host="0.0.0.0",
|
|
||||||
port=self.port,
|
|
||||||
shutdown_trigger=self.shutdown_trigger,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def shutdown_trigger(self):
|
|
||||||
await self.shutdown_event.wait()
|
|
||||||
|
|
||||||
async def check_online(self, appid: str):
|
|
||||||
"""检查 APPID 对应的设备是否在线。"""
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/login/checkOnline",
|
|
||||||
headers=self.headers,
|
|
||||||
json={"appId": appid},
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
return json_blob["data"]
|
|
||||||
|
|
||||||
async def logout(self):
|
|
||||||
"""登出 gewechat。"""
|
|
||||||
if self.appid:
|
|
||||||
online = await self.check_online(self.appid)
|
|
||||||
if online:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/login/logout",
|
|
||||||
headers=self.headers,
|
|
||||||
json={"appId": self.appid},
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.info(f"登出结果: {json_blob}")
|
|
||||||
|
|
||||||
async def login(self):
|
|
||||||
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
|
||||||
if self.token is None:
|
|
||||||
await self.get_token_id()
|
|
||||||
|
|
||||||
self.multimedia_downloader = GeweDownloader(
|
|
||||||
self.base_url, self.download_base_url, self.token
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.appid:
|
|
||||||
try:
|
|
||||||
online = await self.check_online(self.appid)
|
|
||||||
if online:
|
|
||||||
logger.info(f"APPID: {self.appid} 已在线")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"检查在线状态失败: {e}")
|
|
||||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
|
||||||
self.appid = None
|
|
||||||
|
|
||||||
payload = {"appId": self.appid}
|
|
||||||
|
|
||||||
if self.appid:
|
|
||||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/login/getLoginQrCode",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
if json_blob["ret"] != 200:
|
|
||||||
error_msg = json_blob.get("data", {}).get("msg", "")
|
|
||||||
if "设备不存在" in error_msg:
|
|
||||||
logger.error(
|
|
||||||
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
|
||||||
)
|
|
||||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
|
||||||
self.appid = None
|
|
||||||
return await self.login()
|
|
||||||
else:
|
|
||||||
raise Exception(f"获取二维码失败: {json_blob}")
|
|
||||||
qr_data = json_blob["data"]["qrData"]
|
|
||||||
qr_uuid = json_blob["data"]["uuid"]
|
|
||||||
appid = json_blob["data"]["appId"]
|
|
||||||
logger.info(f"APPID: {appid}")
|
|
||||||
logger.warning(
|
|
||||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# 执行登录
|
|
||||||
retry_cnt = 64
|
|
||||||
payload.update({"uuid": qr_uuid, "appId": appid})
|
|
||||||
while retry_cnt > 0:
|
|
||||||
retry_cnt -= 1
|
|
||||||
|
|
||||||
# 需要验证码
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
|
||||||
if os.path.exists(code_file_path):
|
|
||||||
with open(code_file_path, "r") as f:
|
|
||||||
code = f.read().strip()
|
|
||||||
if not code:
|
|
||||||
logger.warning(
|
|
||||||
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
continue
|
|
||||||
payload["captchCode"] = code
|
|
||||||
logger.info(f"使用验证码: {code}")
|
|
||||||
try:
|
|
||||||
os.remove(code_file_path)
|
|
||||||
except Exception:
|
|
||||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/login/checkLogin",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.info(f"检查登录状态: {json_blob}")
|
|
||||||
|
|
||||||
ret = json_blob["ret"]
|
|
||||||
msg = ""
|
|
||||||
if json_blob["data"] and "msg" in json_blob["data"]:
|
|
||||||
msg = json_blob["data"]["msg"]
|
|
||||||
if ret == 500 and "安全验证码" in msg:
|
|
||||||
logger.warning(
|
|
||||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if "status" in json_blob["data"]:
|
|
||||||
status = json_blob["data"]["status"]
|
|
||||||
nickname = json_blob["data"].get("nickName", "")
|
|
||||||
if status == 1:
|
|
||||||
logger.info(f"等待确认...{nickname}")
|
|
||||||
elif status == 2:
|
|
||||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
|
||||||
break
|
|
||||||
elif status == 0:
|
|
||||||
logger.info("等待扫码...")
|
|
||||||
else:
|
|
||||||
logger.warning(f"未知状态: {status}")
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
if appid:
|
|
||||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
|
||||||
self.appid = appid
|
|
||||||
logger.info(f"已保存 APPID: {appid}")
|
|
||||||
|
|
||||||
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
|
||||||
"""获取群成员列表。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
|
||||||
"""
|
|
||||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/group/getChatroomMemberList",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
return json_blob["data"]
|
|
||||||
|
|
||||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
|
||||||
"""发送纯文本消息"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
if ats:
|
|
||||||
payload["ats"] = ats
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/postText", headers=self.headers, json=payload
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"发送消息结果: {json_blob}")
|
|
||||||
|
|
||||||
async def post_image(self, to_wxid, image_url: str):
|
|
||||||
"""发送图片消息"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"imgUrl": image_url,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"发送图片结果: {json_blob}")
|
|
||||||
|
|
||||||
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
|
||||||
"""发送emoji消息"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"emojiMd5": emoji_md5,
|
|
||||||
"emojiSize": emoji_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
|
||||||
try:
|
|
||||||
if emoji_md5 != "" and emoji_size != "":
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/postEmoji",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.info(
|
|
||||||
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self.post_image(to_wxid, cdnurl)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
async def post_video(
|
|
||||||
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
|
||||||
):
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"videoUrl": video_url,
|
|
||||||
"thumbUrl": thumb_url,
|
|
||||||
"videoDuration": video_duration,
|
|
||||||
}
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"发送视频结果: {json_blob}")
|
|
||||||
|
|
||||||
async def forward_video(self, to_wxid, cnd_xml: str):
|
|
||||||
"""转发视频
|
|
||||||
|
|
||||||
Args:
|
|
||||||
to_wxid (str): 发送给谁
|
|
||||||
cnd_xml (str): 视频消息的cdn信息
|
|
||||||
"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"xml": cnd_xml,
|
|
||||||
}
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/forwardVideo",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"转发视频结果: {json_blob}")
|
|
||||||
|
|
||||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
|
||||||
"""发送语音信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
voice_url (str): 语音文件的网络链接
|
|
||||||
voice_duration (int): 语音时长,毫秒
|
|
||||||
"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"voiceUrl": voice_url,
|
|
||||||
"voiceDuration": voice_duration,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"发送语音: {payload}")
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
|
||||||
|
|
||||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
|
||||||
"""发送文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
to_wxid (string): 微信ID
|
|
||||||
file_url (str): 文件的网络链接
|
|
||||||
file_name (str): 文件名
|
|
||||||
"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"toWxid": to_wxid,
|
|
||||||
"fileUrl": file_url,
|
|
||||||
"fileName": file_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"发送文件结果: {json_blob}")
|
|
||||||
|
|
||||||
async def add_friend(self, v3: str, v4: str, content: str):
|
|
||||||
"""申请添加好友"""
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"scene": 3,
|
|
||||||
"content": content,
|
|
||||||
"v4": v4,
|
|
||||||
"v3": v3,
|
|
||||||
"option": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/contacts/addContacts",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"申请添加好友结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
|
|
||||||
async def get_group(self, group_id: str):
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"chatroomId": group_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/group/getChatroomInfo",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"获取群信息结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
|
|
||||||
async def get_group_member(self, group_id: str):
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"chatroomId": group_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/group/getChatroomMemberList",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"获取群信息结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
|
|
||||||
async def accept_group_invite(self, url: str):
|
|
||||||
"""同意进群"""
|
|
||||||
payload = {"appId": self.appid, "url": url}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/group/agreeJoinRoom",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"获取群信息结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
|
|
||||||
async def add_group_member_to_friend(
|
|
||||||
self, group_id: str, to_wxid: str, content: str
|
|
||||||
):
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"chatroomId": group_id,
|
|
||||||
"content": content,
|
|
||||||
"memberWxid": to_wxid,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/group/addGroupMemberAsFriend",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"获取群信息结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
|
|
||||||
async def get_user_or_group_info(self, *ids):
|
|
||||||
"""
|
|
||||||
获取用户或群组信息。
|
|
||||||
|
|
||||||
:param ids: 可变数量的 wxid 参数
|
|
||||||
"""
|
|
||||||
|
|
||||||
wxids_str = list(ids)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"appId": self.appid,
|
|
||||||
"wxids": wxids_str, # 使用逗号分隔的字符串
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/contacts/getDetailInfo",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"获取群信息结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
|
|
||||||
async def get_contacts_list(self):
|
|
||||||
"""
|
|
||||||
获取通讯录列表
|
|
||||||
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
|
||||||
"""
|
|
||||||
payload = {"appId": self.appid}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{self.base_url}/contacts/fetchContactsList",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
) as resp:
|
|
||||||
json_blob = await resp.json()
|
|
||||||
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
|
||||||
return json_blob
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
from astrbot import logger
|
|
||||||
import aiohttp
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class GeweDownloader:
|
|
||||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
|
||||||
self.base_url = base_url
|
|
||||||
self.download_base_url = download_base_url
|
|
||||||
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
|
|
||||||
|
|
||||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f"{baseurl}{route}", headers=self.headers, json=payload
|
|
||||||
) as resp:
|
|
||||||
return await resp.read()
|
|
||||||
|
|
||||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
|
||||||
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
|
|
||||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
|
||||||
|
|
||||||
async def download_image(self, appid: str, xml: str) -> str:
|
|
||||||
"""返回一个可下载的 URL"""
|
|
||||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
|
||||||
|
|
||||||
for choice in choices:
|
|
||||||
try:
|
|
||||||
payload = {"appId": appid, "xml": xml, "type": choice}
|
|
||||||
data = await self._post_json(
|
|
||||||
self.base_url, "/message/downloadImage", payload
|
|
||||||
)
|
|
||||||
json_blob = json.loads(data)
|
|
||||||
if "fileUrl" in json_blob["data"]:
|
|
||||||
return self.download_base_url + json_blob["data"]["fileUrl"]
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(f"gewe download image: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
raise Exception("无法下载图片")
|
|
||||||
|
|
||||||
async def download_emoji_md5(self, app_id, emoji_md5):
|
|
||||||
"""下载emoji"""
|
|
||||||
try:
|
|
||||||
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
|
||||||
|
|
||||||
# gewe 计划中的接口,暂时没有实现。返回代码404
|
|
||||||
data = await self._post_json(
|
|
||||||
self.base_url, "/message/downloadEmojiMd5", payload
|
|
||||||
)
|
|
||||||
json_blob = json.loads(data)
|
|
||||||
return json_blob
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(f"gewe download emoji: {e}")
|
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import re
|
|
||||||
import wave
|
|
||||||
import uuid
|
|
||||||
import traceback
|
|
||||||
import os
|
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
from astrbot.core.utils.io import download_file
|
|
||||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
|
||||||
from astrbot.api.message_components import (
|
|
||||||
Plain,
|
|
||||||
Image,
|
|
||||||
Record,
|
|
||||||
At,
|
|
||||||
File,
|
|
||||||
Video,
|
|
||||||
WechatEmoji as Emoji,
|
|
||||||
)
|
|
||||||
from .client import SimpleGewechatClient
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
||||||
|
|
||||||
|
|
||||||
def get_wav_duration(file_path):
|
|
||||||
with wave.open(file_path, "rb") as wav_file:
|
|
||||||
file_size = os.path.getsize(file_path)
|
|
||||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
|
||||||
if n_frames == 2147483647:
|
|
||||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
|
||||||
elif n_frames == 0:
|
|
||||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
|
||||||
else:
|
|
||||||
duration = n_frames / float(framerate)
|
|
||||||
return duration
|
|
||||||
|
|
||||||
|
|
||||||
class GewechatPlatformEvent(AstrMessageEvent):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_str: str,
|
|
||||||
message_obj: AstrBotMessage,
|
|
||||||
platform_meta: PlatformMetadata,
|
|
||||||
session_id: str,
|
|
||||||
client: SimpleGewechatClient,
|
|
||||||
):
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
||||||
self.client = client
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def send_with_client(
|
|
||||||
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
|
||||||
):
|
|
||||||
if not to_wxid:
|
|
||||||
logger.error("无法获取到 to_wxid。")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 检查@
|
|
||||||
ats = []
|
|
||||||
ats_names = []
|
|
||||||
for comp in message.chain:
|
|
||||||
if isinstance(comp, At):
|
|
||||||
ats.append(comp.qq)
|
|
||||||
ats_names.append(comp.name)
|
|
||||||
has_at = False
|
|
||||||
|
|
||||||
for comp in message.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
text = comp.text
|
|
||||||
payload = {
|
|
||||||
"to_wxid": to_wxid,
|
|
||||||
"content": text,
|
|
||||||
}
|
|
||||||
if not has_at and ats:
|
|
||||||
ats = f"{','.join(ats)}"
|
|
||||||
ats_names = f"@{' @'.join(ats_names)}"
|
|
||||||
text = f"{ats_names} {text}"
|
|
||||||
payload["content"] = text
|
|
||||||
payload["ats"] = ats
|
|
||||||
has_at = True
|
|
||||||
await client.post_text(**payload)
|
|
||||||
|
|
||||||
elif isinstance(comp, Image):
|
|
||||||
img_path = await comp.convert_to_file_path()
|
|
||||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
|
||||||
token = await client._register_file(img_path)
|
|
||||||
img_url = f"{client.file_server_url}/{token}"
|
|
||||||
logger.debug(f"gewe callback img url: {img_url}")
|
|
||||||
await client.post_image(to_wxid, img_url)
|
|
||||||
elif isinstance(comp, Video):
|
|
||||||
if comp.cover != "":
|
|
||||||
await client.forward_video(to_wxid, comp.cover)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from pyffmpeg import FFmpeg
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
logger.error(
|
|
||||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
|
||||||
)
|
|
||||||
raise ModuleNotFoundError(
|
|
||||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
|
||||||
)
|
|
||||||
|
|
||||||
video_url = comp.file
|
|
||||||
# 根据 url 下载视频
|
|
||||||
if video_url.startswith("http"):
|
|
||||||
video_filename = f"{uuid.uuid4()}.mp4"
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
video_path = os.path.join(temp_dir, video_filename)
|
|
||||||
await download_file(video_url, video_path)
|
|
||||||
else:
|
|
||||||
video_path = video_url
|
|
||||||
|
|
||||||
video_token = await client._register_file(video_path)
|
|
||||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
|
||||||
|
|
||||||
# 获取视频第一帧
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
thumb_path = os.path.join(
|
|
||||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
|
||||||
)
|
|
||||||
|
|
||||||
video_path = video_path.replace(" ", "\\ ")
|
|
||||||
try:
|
|
||||||
ff = FFmpeg()
|
|
||||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
|
||||||
ff.options(command)
|
|
||||||
thumb_token = await client._register_file(thumb_path)
|
|
||||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取视频第一帧失败: {e}")
|
|
||||||
|
|
||||||
# 获取视频时长
|
|
||||||
try:
|
|
||||||
from pyffmpeg import FFprobe
|
|
||||||
|
|
||||||
# 创建 FFprobe 实例
|
|
||||||
ffprobe = FFprobe(video_url)
|
|
||||||
# 获取时长字符串
|
|
||||||
duration_str = ffprobe.duration
|
|
||||||
# 处理时长字符串
|
|
||||||
video_duration = float(duration_str.replace(":", ""))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取时长失败: {e}")
|
|
||||||
video_duration = 10
|
|
||||||
|
|
||||||
# 发送视频
|
|
||||||
await client.post_video(
|
|
||||||
to_wxid, video_callback_url, thumb_url, video_duration
|
|
||||||
)
|
|
||||||
|
|
||||||
# 删除临时缩略图文件
|
|
||||||
if os.path.exists(thumb_path):
|
|
||||||
os.remove(thumb_path)
|
|
||||||
elif isinstance(comp, Record):
|
|
||||||
# 默认已经存在 data/temp 中
|
|
||||||
record_url = comp.file
|
|
||||||
record_path = await comp.convert_to_file_path()
|
|
||||||
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
|
||||||
try:
|
|
||||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
|
||||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
|
||||||
if duration == 0:
|
|
||||||
duration = get_wav_duration(record_path)
|
|
||||||
token = await client._register_file(silk_path)
|
|
||||||
record_url = f"{client.file_server_url}/{token}"
|
|
||||||
logger.debug(f"gewe callback record url: {record_url}")
|
|
||||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
|
||||||
elif isinstance(comp, File):
|
|
||||||
file_path = comp.file
|
|
||||||
file_name = comp.name
|
|
||||||
if file_path.startswith("file:///"):
|
|
||||||
file_path = file_path[8:]
|
|
||||||
elif file_path.startswith("http"):
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
||||||
temp_file_path = os.path.join(temp_dir, file_name)
|
|
||||||
await download_file(file_path, temp_file_path)
|
|
||||||
file_path = temp_file_path
|
|
||||||
else:
|
|
||||||
file_path = file_path
|
|
||||||
|
|
||||||
token = await client._register_file(file_path)
|
|
||||||
file_url = f"{client.file_server_url}/{token}"
|
|
||||||
logger.debug(f"gewe callback file url: {file_url}")
|
|
||||||
await client.post_file(to_wxid, file_url, file_name)
|
|
||||||
elif isinstance(comp, Emoji):
|
|
||||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
|
||||||
elif isinstance(comp, At):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
|
||||||
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
|
||||||
await super().send(message)
|
|
||||||
|
|
||||||
async def get_group(self, group_id=None, **kwargs):
|
|
||||||
# 确定有效的 group_id
|
|
||||||
if group_id is None:
|
|
||||||
group_id = self.get_group_id()
|
|
||||||
|
|
||||||
if not group_id:
|
|
||||||
return None
|
|
||||||
|
|
||||||
res = await self.client.get_group(group_id)
|
|
||||||
data: dict = res["data"]
|
|
||||||
|
|
||||||
if not data["chatroomId"]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
members = [
|
|
||||||
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
|
||||||
for member in data.get("memberList", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
return Group(
|
|
||||||
group_id=data["chatroomId"],
|
|
||||||
group_name=data.get("nickName"),
|
|
||||||
group_avatar=data.get("smallHeadImgUrl"),
|
|
||||||
group_owner=data.get("chatRoomOwner"),
|
|
||||||
members=members,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_streaming(
|
|
||||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
|
||||||
):
|
|
||||||
if not use_fallback:
|
|
||||||
buffer = None
|
|
||||||
async for chain in generator:
|
|
||||||
if not buffer:
|
|
||||||
buffer = chain
|
|
||||||
else:
|
|
||||||
buffer.chain.extend(chain.chain)
|
|
||||||
if not buffer:
|
|
||||||
return
|
|
||||||
buffer.squash_plain()
|
|
||||||
await self.send(buffer)
|
|
||||||
return await super().send_streaming(generator, use_fallback)
|
|
||||||
|
|
||||||
buffer = ""
|
|
||||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
|
||||||
|
|
||||||
async for chain in generator:
|
|
||||||
if isinstance(chain, MessageChain):
|
|
||||||
for comp in chain.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
buffer += comp.text
|
|
||||||
if any(p in buffer for p in "。?!~…"):
|
|
||||||
buffer = await self.process_buffer(buffer, pattern)
|
|
||||||
else:
|
|
||||||
await self.send(MessageChain(chain=[comp]))
|
|
||||||
await asyncio.sleep(1.5) # 限速
|
|
||||||
|
|
||||||
if buffer.strip():
|
|
||||||
await self.send(MessageChain([Plain(buffer)]))
|
|
||||||
return await super().send_streaming(generator, use_fallback)
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
|
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
|
||||||
from ...register import register_platform_adapter
|
|
||||||
from .gewechat_event import GewechatPlatformEvent
|
|
||||||
from .client import SimpleGewechatClient
|
|
||||||
from astrbot import logger
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
from typing import override
|
|
||||||
else:
|
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
|
||||||
class GewechatPlatformAdapter(Platform):
|
|
||||||
def __init__(
|
|
||||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
|
||||||
) -> None:
|
|
||||||
super().__init__(event_queue)
|
|
||||||
self.config = platform_config
|
|
||||||
self.settingss = platform_settings
|
|
||||||
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
|
||||||
self.client = None
|
|
||||||
|
|
||||||
self.client = SimpleGewechatClient(
|
|
||||||
self.config["base_url"],
|
|
||||||
self.config["nickname"],
|
|
||||||
self.config["host"],
|
|
||||||
self.config["port"],
|
|
||||||
self._event_queue,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_event_received(abm: AstrBotMessage):
|
|
||||||
await self.handle_msg(abm)
|
|
||||||
|
|
||||||
self.client.on_event_received = on_event_received
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def send_by_session(
|
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
|
||||||
):
|
|
||||||
session_id = session.session_id
|
|
||||||
if "#" in session_id:
|
|
||||||
# unique session
|
|
||||||
to_wxid = session_id.split("#")[1]
|
|
||||||
else:
|
|
||||||
to_wxid = session_id
|
|
||||||
|
|
||||||
await GewechatPlatformEvent.send_with_client(
|
|
||||||
message_chain, to_wxid, self.client
|
|
||||||
)
|
|
||||||
|
|
||||||
await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
return PlatformMetadata(
|
|
||||||
name="gewechat",
|
|
||||||
description="基于 gewechat 的 Wechat 适配器",
|
|
||||||
id=self.config.get("id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
self.client.shutdown_event.set()
|
|
||||||
try:
|
|
||||||
await self.client.server.shutdown()
|
|
||||||
except Exception as _:
|
|
||||||
pass
|
|
||||||
logger.info("Gewechat 适配器已被优雅地关闭。")
|
|
||||||
|
|
||||||
async def logout(self):
|
|
||||||
await self.client.logout()
|
|
||||||
|
|
||||||
@override
|
|
||||||
def run(self):
|
|
||||||
return self._run()
|
|
||||||
|
|
||||||
async def _run(self):
|
|
||||||
await self.client.login()
|
|
||||||
await self.client.start_polling()
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
|
||||||
if message.type == MessageType.GROUP_MESSAGE:
|
|
||||||
if self.settingss["unique_session"]:
|
|
||||||
message.session_id = message.sender.user_id + "#" + message.group_id
|
|
||||||
|
|
||||||
message_event = GewechatPlatformEvent(
|
|
||||||
message_str=message.message_str,
|
|
||||||
message_obj=message,
|
|
||||||
platform_meta=self.meta(),
|
|
||||||
session_id=message.session_id,
|
|
||||||
client=self.client,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.commit_event(message_event)
|
|
||||||
|
|
||||||
def get_client(self) -> SimpleGewechatClient:
|
|
||||||
return self.client
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
from defusedxml import ElementTree as eT
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.message_components import (
|
|
||||||
WechatEmoji as Emoji,
|
|
||||||
Reply,
|
|
||||||
Plain,
|
|
||||||
BaseMessageComponent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GeweDataParser:
|
|
||||||
def __init__(self, data, is_private_chat):
|
|
||||||
self.data = data
|
|
||||||
self.is_private_chat = is_private_chat
|
|
||||||
|
|
||||||
def _format_to_xml(self):
|
|
||||||
return eT.fromstring(self.data)
|
|
||||||
|
|
||||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
|
||||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
|
||||||
if appmsg_type is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
match appmsg_type.text:
|
|
||||||
case "57":
|
|
||||||
return self.parse_reply()
|
|
||||||
|
|
||||||
def parse_emoji(self) -> Emoji | None:
|
|
||||||
try:
|
|
||||||
emoji_element = self._format_to_xml().find(".//emoji")
|
|
||||||
# 提取 md5 和 len 属性
|
|
||||||
if emoji_element is not None:
|
|
||||||
md5_value = emoji_element.get("md5")
|
|
||||||
emoji_size = emoji_element.get("len")
|
|
||||||
cdnurl = emoji_element.get("cdnurl")
|
|
||||||
|
|
||||||
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
|
||||||
|
|
||||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
|
||||||
"""解析引用消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
replied_id = -1
|
|
||||||
replied_uid = 0
|
|
||||||
replied_nickname = ""
|
|
||||||
replied_content = "" # 被引用者说的内容
|
|
||||||
content = "" # 引用者说的内容
|
|
||||||
|
|
||||||
root = self._format_to_xml()
|
|
||||||
refermsg = root.find(".//refermsg")
|
|
||||||
if refermsg is not None:
|
|
||||||
# 被引用的信息
|
|
||||||
svrid = refermsg.find("svrid")
|
|
||||||
fromusr = refermsg.find("fromusr")
|
|
||||||
displayname = refermsg.find("displayname")
|
|
||||||
refermsg_content = refermsg.find("content")
|
|
||||||
if svrid is not None:
|
|
||||||
replied_id = svrid.text
|
|
||||||
if fromusr is not None:
|
|
||||||
replied_uid = fromusr.text
|
|
||||||
if displayname is not None:
|
|
||||||
replied_nickname = displayname.text
|
|
||||||
if refermsg_content is not None:
|
|
||||||
# 处理引用嵌套,包括嵌套公众号消息
|
|
||||||
if refermsg_content.text.startswith(
|
|
||||||
"<msg>"
|
|
||||||
) or refermsg_content.text.startswith("<?xml"):
|
|
||||||
try:
|
|
||||||
logger.debug("gewechat: Reference message is nested")
|
|
||||||
refer_root = eT.fromstring(refermsg_content.text)
|
|
||||||
img = refer_root.find("img")
|
|
||||||
if img is not None:
|
|
||||||
replied_content = "[图片]"
|
|
||||||
else:
|
|
||||||
app_msg = refer_root.find("appmsg")
|
|
||||||
refermsg_content_title = app_msg.find("title")
|
|
||||||
logger.debug(
|
|
||||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
|
||||||
)
|
|
||||||
replied_content = refermsg_content_title.text
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"gewechat: nested failed, {e}")
|
|
||||||
# 处理异常情况
|
|
||||||
replied_content = refermsg_content.text
|
|
||||||
else:
|
|
||||||
replied_content = refermsg_content.text
|
|
||||||
|
|
||||||
# 提取引用者说的内容
|
|
||||||
title = root.find(".//appmsg/title")
|
|
||||||
if title is not None:
|
|
||||||
content = title.text
|
|
||||||
|
|
||||||
reply_seg = Reply(
|
|
||||||
id=replied_id,
|
|
||||||
chain=[Plain(replied_content)],
|
|
||||||
sender_id=replied_uid,
|
|
||||||
sender_nickname=replied_nickname,
|
|
||||||
message_str=replied_content,
|
|
||||||
)
|
|
||||||
plain_seg = Plain(content)
|
|
||||||
return [reply_seg, plain_seg]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
|
||||||
@@ -28,10 +28,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
self.send_buffer = None
|
self.send_buffer = None
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
if not self.send_buffer:
|
self.send_buffer = message
|
||||||
self.send_buffer = message
|
await self._post_send()
|
||||||
else:
|
|
||||||
self.send_buffer.chain.extend(message.chain)
|
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
"""流式输出仅支持消息列表私聊"""
|
"""流式输出仅支持消息列表私聊"""
|
||||||
|
|||||||
@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
|
|||||||
base64_content = base64.b64encode(content).decode("utf-8")
|
base64_content = base64.b64encode(content).decode("utf-8")
|
||||||
return base64_content
|
return base64_content
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
logger.error(
|
||||||
|
f"Failed to download slack file: {resp.status} {await resp.text()}"
|
||||||
|
)
|
||||||
raise Exception(f"下载文件失败: {resp.status}")
|
raise Exception(f"下载文件失败: {resp.status}")
|
||||||
|
|
||||||
async def run(self) -> Awaitable[Any]:
|
async def run(self) -> Awaitable[Any]:
|
||||||
|
|||||||
@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||||
}
|
}
|
||||||
file_url = response["files"][0]["permalink"]
|
file_url = response["files"][0]["permalink"]
|
||||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
return {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||||
|
},
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||||
|
|
||||||
|
|||||||
@@ -40,20 +40,21 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
def _split_message(self, text: str) -> list[str]:
|
@classmethod
|
||||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
def _split_message(cls, text: str) -> list[str]:
|
||||||
|
if len(text) <= cls.MAX_MESSAGE_LENGTH:
|
||||||
return [text]
|
return [text]
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
while text:
|
while text:
|
||||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
if len(text) <= cls.MAX_MESSAGE_LENGTH:
|
||||||
chunks.append(text)
|
chunks.append(text)
|
||||||
break
|
break
|
||||||
|
|
||||||
split_point = self.MAX_MESSAGE_LENGTH
|
split_point = cls.MAX_MESSAGE_LENGTH
|
||||||
segment = text[: self.MAX_MESSAGE_LENGTH]
|
segment = text[: cls.MAX_MESSAGE_LENGTH]
|
||||||
|
|
||||||
for _, pattern in self.SPLIT_PATTERNS.items():
|
for _, pattern in cls.SPLIT_PATTERNS.items():
|
||||||
if matches := list(pattern.finditer(segment)):
|
if matches := list(pattern.finditer(segment)):
|
||||||
last_match = matches[-1]
|
last_match = matches[-1]
|
||||||
split_point = last_match.end()
|
split_point = last_match.end()
|
||||||
@@ -64,8 +65,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
@classmethod
|
||||||
async def send_with_client(
|
async def send_with_client(
|
||||||
self, client: ExtBot, message: MessageChain, user_name: str
|
cls, client: ExtBot, message: MessageChain, user_name: str
|
||||||
):
|
):
|
||||||
image_path = None
|
image_path = None
|
||||||
|
|
||||||
@@ -97,7 +99,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
if at_user_id and not at_flag:
|
if at_user_id and not at_flag:
|
||||||
i.text = f"@{at_user_id} {i.text}"
|
i.text = f"@{at_user_id} {i.text}"
|
||||||
at_flag = True
|
at_flag = True
|
||||||
chunks = self._split_message(i.text)
|
chunks = cls._split_message(i.text)
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
try:
|
try:
|
||||||
md_text = telegramify_markdown.markdownify(
|
md_text = telegramify_markdown.markdownify(
|
||||||
@@ -158,6 +160,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
if isinstance(chain, MessageChain):
|
if isinstance(chain, MessageChain):
|
||||||
|
if chain.type == "break":
|
||||||
|
# 分割符
|
||||||
|
message_id = None # 重置消息 ID
|
||||||
|
delta = "" # 重置 delta
|
||||||
|
continue
|
||||||
|
|
||||||
# 处理消息链中的每个组件
|
# 处理消息链中的每个组件
|
||||||
for i in chain.chain:
|
for i in chain.chain:
|
||||||
if isinstance(i, Plain):
|
if isinstance(i, Plain):
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any, Callable
|
||||||
from astrbot.core.platform import (
|
from astrbot.core.platform import (
|
||||||
Platform,
|
Platform,
|
||||||
AstrBotMessage,
|
AstrBotMessage,
|
||||||
@@ -13,7 +13,7 @@ from astrbot.core.platform import (
|
|||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
|
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core import web_chat_queue
|
from .webchat_queue_mgr import webchat_queue_mgr, WebChatQueueMgr
|
||||||
from .webchat_event import WebChatMessageEvent
|
from .webchat_event import WebChatMessageEvent
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
@@ -21,14 +21,46 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
|
|
||||||
|
|
||||||
class QueueListener:
|
class QueueListener:
|
||||||
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
|
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
|
||||||
self.queue = queue
|
self.webchat_queue_mgr = webchat_queue_mgr
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
self.running_tasks = set()
|
||||||
|
|
||||||
|
async def listen_to_queue(self, conversation_id: str):
|
||||||
|
"""Listen to a specific conversation queue"""
|
||||||
|
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = await queue.get()
|
||||||
|
await self.callback(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error processing message from conversation {conversation_id}: {e}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
"""Monitor for new conversation queues and start listeners"""
|
||||||
|
monitored_conversations = set()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
data = await self.queue.get()
|
# Check for new conversations
|
||||||
await self.callback(data)
|
current_conversations = set(self.webchat_queue_mgr.queues.keys())
|
||||||
|
new_conversations = current_conversations - monitored_conversations
|
||||||
|
|
||||||
|
# Start listeners for new conversations
|
||||||
|
for conversation_id in new_conversations:
|
||||||
|
task = asyncio.create_task(self.listen_to_queue(conversation_id))
|
||||||
|
self.running_tasks.add(task)
|
||||||
|
task.add_done_callback(self.running_tasks.discard)
|
||||||
|
monitored_conversations.add(conversation_id)
|
||||||
|
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||||
|
|
||||||
|
# Clean up monitored conversations that no longer exist
|
||||||
|
removed_conversations = monitored_conversations - current_conversations
|
||||||
|
monitored_conversations -= removed_conversations
|
||||||
|
|
||||||
|
await asyncio.sleep(1) # Check for new conversations every second
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("webchat", "webchat")
|
@register_platform_adapter("webchat", "webchat")
|
||||||
@@ -45,7 +77,7 @@ class WebChatAdapter(Platform):
|
|||||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
name="webchat", description="webchat", id=self.config.get("id")
|
name="webchat", description="webchat", id=self.config.get("id", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
@@ -105,7 +137,7 @@ class WebChatAdapter(Platform):
|
|||||||
abm = await self.convert_message(data)
|
abm = await self.convert_message(data)
|
||||||
await self.handle_msg(abm)
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
bot = QueueListener(web_chat_queue, callback)
|
bot = QueueListener(webchat_queue_mgr, callback)
|
||||||
return bot.run()
|
return bot.run()
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
@@ -119,6 +151,10 @@ class WebChatAdapter(Platform):
|
|||||||
session_id=message.session_id,
|
session_id=message.session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_, _, payload = message.raw_message # type: ignore
|
||||||
|
message_event.set_extra("selected_provider", payload.get("selected_provider"))
|
||||||
|
message_event.set_extra("selected_model", payload.get("selected_model"))
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from astrbot.api import logger
|
|||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image, Record
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core import web_chat_back_queue
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
from .webchat_queue_mgr import webchat_queue_mgr
|
||||||
|
|
||||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||||
|
|
||||||
@@ -18,13 +18,18 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||||
|
cid = session_id.split("!")[-1]
|
||||||
|
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||||
if not message:
|
if not message:
|
||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{"type": "end", "data": "", "streaming": False}
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": "",
|
||||||
|
"streaming": False,
|
||||||
|
} # end means this request is finished
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
cid = session_id.split("!")[-1]
|
|
||||||
data = ""
|
data = ""
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
@@ -35,6 +40,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
"cid": cid,
|
"cid": cid,
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
|
"chain_type": message.type,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
@@ -97,29 +103,35 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||||
await web_chat_back_queue.put(
|
|
||||||
{
|
|
||||||
"type": "end",
|
|
||||||
"data": "",
|
|
||||||
"streaming": False,
|
|
||||||
"cid": self.session_id.split("!")[-1],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
final_data = ""
|
final_data = ""
|
||||||
|
cid = self.session_id.split("!")[-1]
|
||||||
|
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
|
if chain.type == "break" and final_data:
|
||||||
|
# 分割符
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "break", # break means a segment end
|
||||||
|
"data": final_data,
|
||||||
|
"streaming": True,
|
||||||
|
"cid": cid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
final_data = ""
|
||||||
|
continue
|
||||||
final_data += await WebChatMessageEvent._send(
|
final_data += await WebChatMessageEvent._send(
|
||||||
chain, session_id=self.session_id, streaming=True
|
chain, session_id=self.session_id, streaming=True
|
||||||
)
|
)
|
||||||
|
|
||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "end",
|
"type": "complete", # complete means we return the final result
|
||||||
"data": final_data,
|
"data": final_data,
|
||||||
"streaming": True,
|
"streaming": True,
|
||||||
"cid": self.session_id.split("!")[-1],
|
"cid": cid,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await super().send_streaming(generator, use_fallback)
|
await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
35
astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
Normal file
35
astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class WebChatQueueMgr:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.queues = {}
|
||||||
|
"""Conversation ID to asyncio.Queue mapping"""
|
||||||
|
self.back_queues = {}
|
||||||
|
"""Conversation ID to asyncio.Queue mapping for responses"""
|
||||||
|
|
||||||
|
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||||
|
"""Get or create a queue for the given conversation ID"""
|
||||||
|
if conversation_id not in self.queues:
|
||||||
|
self.queues[conversation_id] = asyncio.Queue()
|
||||||
|
return self.queues[conversation_id]
|
||||||
|
|
||||||
|
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||||
|
"""Get or create a back queue for the given conversation ID"""
|
||||||
|
if conversation_id not in self.back_queues:
|
||||||
|
self.back_queues[conversation_id] = asyncio.Queue()
|
||||||
|
return self.back_queues[conversation_id]
|
||||||
|
|
||||||
|
def remove_queues(self, conversation_id: str):
|
||||||
|
"""Remove queues for the given conversation ID"""
|
||||||
|
if conversation_id in self.queues:
|
||||||
|
del self.queues[conversation_id]
|
||||||
|
if conversation_id in self.back_queues:
|
||||||
|
del self.back_queues[conversation_id]
|
||||||
|
|
||||||
|
def has_queue(self, conversation_id: str) -> bool:
|
||||||
|
"""Check if a queue exists for the given conversation ID"""
|
||||||
|
return conversation_id in self.queues
|
||||||
|
|
||||||
|
|
||||||
|
webchat_queue_mgr = WebChatQueueMgr()
|
||||||
@@ -210,6 +210,16 @@ class WeChatPadProAdapter(Platform):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _extract_auth_key(self, data):
|
||||||
|
"""Helper method to extract auth_key from response data."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
auth_keys = data.get("authKeys") # 新接口
|
||||||
|
if isinstance(auth_keys, list) and auth_keys:
|
||||||
|
return auth_keys[0]
|
||||||
|
elif isinstance(data, list) and data: # 旧接口
|
||||||
|
return data[0]
|
||||||
|
return None
|
||||||
|
|
||||||
async def generate_auth_key(self):
|
async def generate_auth_key(self):
|
||||||
"""
|
"""
|
||||||
生成授权码。
|
生成授权码。
|
||||||
@@ -218,28 +228,30 @@ class WeChatPadProAdapter(Platform):
|
|||||||
params = {"key": self.admin_key}
|
params = {"key": self.admin_key}
|
||||||
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
||||||
|
|
||||||
|
self.auth_key = None # Reset auth_key before generating a new one
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
try:
|
try:
|
||||||
async with session.post(url, params=params, json=payload) as response:
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"生成授权码失败: {response.status}, {await response.text()}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
# 修正成功判断条件和授权码提取路径
|
if response_data.get("Code") == 200:
|
||||||
if response.status == 200 and response_data.get("Code") == 200:
|
if data := response_data.get("Data"):
|
||||||
# 授权码在 Data 字段的列表中
|
self.auth_key = self._extract_auth_key(data)
|
||||||
if (
|
|
||||||
response_data.get("Data")
|
if self.auth_key:
|
||||||
and isinstance(response_data["Data"], list)
|
logger.info("成功获取授权码")
|
||||||
and len(response_data["Data"]) > 0
|
|
||||||
):
|
|
||||||
self.auth_key = response_data["Data"][0]
|
|
||||||
logger.info(f"成功获取授权码 {self.auth_key[:8]}...")
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"生成授权码成功但未找到授权码: {response_data}"
|
f"生成授权码成功但未找到授权码: {response_data}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"生成授权码失败: {response_data}")
|
||||||
f"生成授权码失败: {response.status}, {response_data}"
|
|
||||||
)
|
|
||||||
except aiohttp.ClientConnectorError as e:
|
except aiohttp.ClientConnectorError as e:
|
||||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
|||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64
|
from astrbot.core.utils.tencent_record_helper import audio_to_tencent_silk_base64
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .wechatpadpro_adapter import WeChatPadProAdapter
|
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||||
@@ -113,7 +113,7 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
|||||||
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
|
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
|
||||||
record_path = await comp.convert_to_file_path()
|
record_path = await comp.convert_to_file_path()
|
||||||
# 默认已经存在 data/temp 中
|
# 默认已经存在 data/temp 中
|
||||||
b64, duration = await wav_to_tencent_silk_base64(record_path)
|
b64, duration = await audio_to_tencent_silk_base64(record_path)
|
||||||
payload = {
|
payload = {
|
||||||
"ToUserName": self.session_id,
|
"ToUserName": self.session_id,
|
||||||
"VoiceData": b64,
|
"VoiceData": b64,
|
||||||
|
|||||||
@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||||
:return: 接口调用结果
|
:return: 接口调用结果
|
||||||
"""
|
"""
|
||||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
data = {
|
||||||
|
"token": token,
|
||||||
|
"cursor": cursor,
|
||||||
|
"limit": limit,
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
}
|
||||||
return self._post("kf/sync_msg", data=data)
|
return self._post("kf/sync_msg", data=data)
|
||||||
|
|
||||||
def get_service_state(self, open_kfid, external_userid):
|
def get_service_state(self, open_kfid, external_userid):
|
||||||
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
}
|
}
|
||||||
return self._post("kf/service_state/get", data=data)
|
return self._post("kf/service_state/get", data=data)
|
||||||
|
|
||||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
def trans_service_state(
|
||||||
|
self, open_kfid, external_userid, service_state, servicer_userid=""
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
变更会话状态
|
变更会话状态
|
||||||
|
|
||||||
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
"""
|
"""
|
||||||
return self._get("kf/customer/get_upgrade_service_config")
|
return self._get("kf/customer/get_upgrade_service_config")
|
||||||
|
|
||||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
def upgrade_service(
|
||||||
|
self, open_kfid, external_userid, service_type, member=None, groupchat=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
为客户升级为专员或客户群服务
|
为客户升级为专员或客户群服务
|
||||||
|
|
||||||
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||||
return self._post("kf/get_corp_statistic", data=data)
|
return self._post("kf/get_corp_statistic", data=data)
|
||||||
|
|
||||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
def get_servicer_statistic(
|
||||||
|
self, start_time, end_time, open_kfid=None, servicer_userid=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取「客户数据统计」接待人员明细数据
|
获取「客户数据统计」接待人员明细数据
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from optionaldict import optionaldict
|
|||||||
|
|
||||||
from wechatpy.client.api.base import BaseWeChatAPI
|
from wechatpy.client.api.base import BaseWeChatAPI
|
||||||
|
|
||||||
|
|
||||||
class WeChatKFMessage(BaseWeChatAPI):
|
class WeChatKFMessage(BaseWeChatAPI):
|
||||||
"""
|
"""
|
||||||
发送微信客服消息
|
发送微信客服消息
|
||||||
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
|
|||||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
def send_msgmenu(
|
||||||
|
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
|
||||||
|
):
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "msgmenu",
|
"msgtype": "msgmenu",
|
||||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
"msgmenu": {
|
||||||
|
"head_content": head_content,
|
||||||
|
"list": menu_list,
|
||||||
|
"tail_content": tail_content,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
def send_location(
|
||||||
|
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
|
||||||
|
):
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "location",
|
"msgtype": "location",
|
||||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
"msgmenu": {
|
||||||
|
"name": name,
|
||||||
|
"address": address,
|
||||||
|
"latitude": latitude,
|
||||||
|
"longitude": longitude,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
def send_miniprogram(
|
||||||
|
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
|
||||||
|
):
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "miniprogram",
|
"msgtype": "miniprogram",
|
||||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
"msgmenu": {
|
||||||
|
"appid": appid,
|
||||||
|
"title": title,
|
||||||
|
"thumb_media_id": thumb_media_id,
|
||||||
|
"pagepath": pagepath,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
self.wexin_event_workers[msg.id] = future
|
self.wexin_event_workers[msg.id] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
result = await asyncio.wait_for(
|
||||||
|
asyncio.shield(future), 60
|
||||||
|
) # wait for 60s
|
||||||
logger.debug(f"Got future result: {result}")
|
logger.debug(f"Got future result: {result}")
|
||||||
self.wexin_event_workers.pop(msg.id, None)
|
self.wexin_event_workers.pop(msg.id, None)
|
||||||
return result # xml. see weixin_offacc_event.py
|
return result # xml. see weixin_offacc_event.py
|
||||||
|
|||||||
@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
return
|
return
|
||||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||||
|
|
||||||
|
|
||||||
if active_send_mode:
|
if active_send_mode:
|
||||||
self.client.message.send_voice(
|
self.client.message.send_voice(
|
||||||
message_obj.sender.user_id,
|
message_obj.sender.user_id,
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class AssistantMessageSegment:
|
|||||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
content: str = None
|
content: str = None
|
||||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
||||||
role: str = "assistant"
|
role: str = "assistant"
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
|
|||||||
}
|
}
|
||||||
if self.content:
|
if self.content:
|
||||||
ret["content"] = self.content
|
ret["content"] = self.content
|
||||||
elif self.tool_calls:
|
if self.tool_calls:
|
||||||
ret["tool_calls"] = self.tool_calls
|
ret["tool_calls"] = self.tool_calls
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@@ -95,27 +95,38 @@ class ProviderRequest:
|
|||||||
"""提示词"""
|
"""提示词"""
|
||||||
session_id: str = ""
|
session_id: str = ""
|
||||||
"""会话 ID"""
|
"""会话 ID"""
|
||||||
image_urls: List[str] = None
|
image_urls: list[str] = field(default_factory=list)
|
||||||
"""图片 URL 列表"""
|
"""图片 URL 列表"""
|
||||||
func_tool: FuncCall = None
|
func_tool: FuncCall | None = None
|
||||||
"""可用的函数工具"""
|
"""可用的函数工具"""
|
||||||
contexts: List = None
|
contexts: list[dict] = field(default_factory=list)
|
||||||
"""上下文。格式与 openai 的上下文格式一致:
|
"""上下文。格式与 openai 的上下文格式一致:
|
||||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||||
"""
|
"""
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
"""系统提示词"""
|
"""系统提示词"""
|
||||||
conversation: Conversation = None
|
conversation: Conversation | None = None
|
||||||
|
|
||||||
tool_calls_result: ToolCallsResult = None
|
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
|
||||||
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||||
|
|
||||||
|
model: str | None = None
|
||||||
|
"""模型名称,为 None 时使用提供商的默认模型"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
|
def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
|
||||||
|
"""添加工具调用结果到请求中"""
|
||||||
|
if not self.tool_calls_result:
|
||||||
|
self.tool_calls_result = []
|
||||||
|
if isinstance(self.tool_calls_result, ToolCallsResult):
|
||||||
|
self.tool_calls_result = [self.tool_calls_result]
|
||||||
|
self.tool_calls_result.append(tool_calls_result)
|
||||||
|
|
||||||
def _print_friendly_context(self):
|
def _print_friendly_context(self):
|
||||||
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||||
if not self.contexts:
|
if not self.contexts:
|
||||||
|
|||||||
@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
|
|||||||
] # json schema 支持的数据类型
|
] # json schema 支持的数据类型
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_config(config: dict) -> dict:
|
||||||
|
"""准备配置,处理嵌套格式"""
|
||||||
|
if "mcpServers" in config and config["mcpServers"]:
|
||||||
|
first_key = next(iter(config["mcpServers"]))
|
||||||
|
config = config["mcpServers"][first_key]
|
||||||
|
config.pop("active", None)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||||
|
"""快速测试 MCP 服务器可达性"""
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
cfg = _prepare_config(config.copy())
|
||||||
|
|
||||||
|
url = cfg["url"]
|
||||||
|
headers = cfg.get("headers", {})
|
||||||
|
timeout = cfg.get("timeout", 10)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
if cfg.get("transport") == "streamable_http":
|
||||||
|
test_payload = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "initialize",
|
||||||
|
"id": 0,
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
**headers,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
},
|
||||||
|
json=test_payload,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return True, ""
|
||||||
|
else:
|
||||||
|
return False, f"HTTP {response.status}: {response.reason}"
|
||||||
|
else:
|
||||||
|
async with session.get(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
**headers,
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
},
|
||||||
|
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return True, ""
|
||||||
|
else:
|
||||||
|
return False, f"HTTP {response.status}: {response.reason}"
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return False, f"连接超时: {timeout}秒"
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"{e!s}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FuncTool:
|
class FuncTool:
|
||||||
"""
|
"""
|
||||||
@@ -80,12 +146,10 @@ class FuncTool:
|
|||||||
if not self.mcp_client or not self.mcp_client.session:
|
if not self.mcp_client or not self.mcp_client.session:
|
||||||
raise Exception(f"MCP client for {self.name} is not available")
|
raise Exception(f"MCP client for {self.name} is not available")
|
||||||
# 使用name属性而不是额外的mcp_tool_name
|
# 使用name属性而不是额外的mcp_tool_name
|
||||||
if ":" in self.name:
|
actual_tool_name = (
|
||||||
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
self.name.split(":")[-1] if ":" in self.name else self.name
|
||||||
actual_tool_name = self.name.split(":")[-1]
|
)
|
||||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||||
else:
|
|
||||||
return await self.mcp_client.session.call_tool(self.name, args)
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown function origin: {self.origin}")
|
raise Exception(f"Unknown function origin: {self.origin}")
|
||||||
|
|
||||||
@@ -100,6 +164,7 @@ class MCPClient:
|
|||||||
self.active: bool = True
|
self.active: bool = True
|
||||||
self.tools: List[mcp.Tool] = []
|
self.tools: List[mcp.Tool] = []
|
||||||
self.server_errlogs: List[str] = []
|
self.server_errlogs: List[str] = []
|
||||||
|
self.running_event = asyncio.Event()
|
||||||
|
|
||||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||||
"""连接到 MCP 服务器
|
"""连接到 MCP 服务器
|
||||||
@@ -112,17 +177,19 @@ class MCPClient:
|
|||||||
Args:
|
Args:
|
||||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||||
"""
|
"""
|
||||||
cfg = mcp_server_config.copy()
|
cfg = _prepare_config(mcp_server_config.copy())
|
||||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
|
||||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
def logging_callback(msg: str):
|
||||||
cfg = cfg["mcpServers"][key_0]
|
# 处理 MCP 服务的错误日志
|
||||||
cfg.pop("active", None) # Remove active flag from config
|
print(f"MCP Server {name} Error: {msg}")
|
||||||
|
self.server_errlogs.append(msg)
|
||||||
|
|
||||||
if "url" in cfg:
|
if "url" in cfg:
|
||||||
is_sse = True
|
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||||
if cfg.get("transport") == "streamable_http":
|
if not success:
|
||||||
is_sse = False
|
raise Exception(error_msg)
|
||||||
if is_sse:
|
|
||||||
|
if cfg.get("transport") != "streamable_http":
|
||||||
# SSE transport method
|
# SSE transport method
|
||||||
self._streams_context = sse_client(
|
self._streams_context = sse_client(
|
||||||
url=cfg["url"],
|
url=cfg["url"],
|
||||||
@@ -130,11 +197,18 @@ class MCPClient:
|
|||||||
timeout=cfg.get("timeout", 5),
|
timeout=cfg.get("timeout", 5),
|
||||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||||
)
|
)
|
||||||
streams = await self._streams_context.__aenter__()
|
streams = await self.exit_stack.enter_async_context(
|
||||||
|
self._streams_context
|
||||||
|
)
|
||||||
|
|
||||||
# Create a new client session
|
# Create a new client session
|
||||||
|
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||||
self.session = await self.exit_stack.enter_async_context(
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
mcp.ClientSession(*streams)
|
mcp.ClientSession(
|
||||||
|
*streams,
|
||||||
|
read_timeout_seconds=read_timeout,
|
||||||
|
logging_callback=logging_callback, # type: ignore
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||||
@@ -148,11 +222,19 @@ class MCPClient:
|
|||||||
sse_read_timeout=sse_read_timeout,
|
sse_read_timeout=sse_read_timeout,
|
||||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||||
)
|
)
|
||||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||||
|
self._streams_context
|
||||||
|
)
|
||||||
|
|
||||||
# Create a new client session
|
# Create a new client session
|
||||||
|
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||||
self.session = await self.exit_stack.enter_async_context(
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
mcp.ClientSession(
|
||||||
|
read_stream=read_s,
|
||||||
|
write_stream=write_s,
|
||||||
|
read_timeout_seconds=read_timeout,
|
||||||
|
logging_callback=logging_callback, # type: ignore
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -172,7 +254,7 @@ class MCPClient:
|
|||||||
logger=logger,
|
logger=logger,
|
||||||
identifier=f"MCPServer-{name}",
|
identifier=f"MCPServer-{name}",
|
||||||
callback=callback,
|
callback=callback,
|
||||||
),
|
), # type: ignore
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -180,19 +262,18 @@ class MCPClient:
|
|||||||
self.session = await self.exit_stack.enter_async_context(
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
mcp.ClientSession(*stdio_transport)
|
mcp.ClientSession(*stdio_transport)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.session.initialize()
|
await self.session.initialize()
|
||||||
|
|
||||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
"""List all tools from the server and save them to self.tools"""
|
"""List all tools from the server and save them to self.tools"""
|
||||||
response = await self.session.list_tools()
|
response = await self.session.list_tools()
|
||||||
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
|
||||||
self.tools = response.tools
|
self.tools = response.tools
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
"""Clean up resources"""
|
"""Clean up resources"""
|
||||||
await self.exit_stack.aclose()
|
await self.exit_stack.aclose()
|
||||||
|
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||||
|
|
||||||
|
|
||||||
class FuncCall:
|
class FuncCall:
|
||||||
@@ -201,8 +282,6 @@ class FuncCall:
|
|||||||
"""内部加载的 func tools"""
|
"""内部加载的 func tools"""
|
||||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||||
"""MCP 服务列表"""
|
"""MCP 服务列表"""
|
||||||
self.mcp_service_queue = asyncio.Queue()
|
|
||||||
"""用于外部控制 MCP 服务的启停"""
|
|
||||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
@@ -258,7 +337,7 @@ class FuncCall:
|
|||||||
return f
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _init_mcp_clients(self) -> None:
|
async def init_mcp_clients(self) -> None:
|
||||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||||
```
|
```
|
||||||
{
|
{
|
||||||
@@ -300,113 +379,64 @@ class FuncCall:
|
|||||||
)
|
)
|
||||||
self.mcp_client_event[name] = event
|
self.mcp_client_event[name] = event
|
||||||
|
|
||||||
async def mcp_service_selector(self):
|
|
||||||
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
|
||||||
|
|
||||||
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
|
||||||
|
|
||||||
{"type": "init"} 初始化所有MCP客户端
|
|
||||||
|
|
||||||
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
|
||||||
|
|
||||||
{"type": "terminate"} 终止所有MCP客户端
|
|
||||||
|
|
||||||
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
data = await self.mcp_service_queue.get()
|
|
||||||
if data["type"] == "init":
|
|
||||||
if "name" in data:
|
|
||||||
event = asyncio.Event()
|
|
||||||
asyncio.create_task(
|
|
||||||
self._init_mcp_client_task_wrapper(
|
|
||||||
data["name"], data["cfg"], event
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.mcp_client_event[data["name"]] = event
|
|
||||||
else:
|
|
||||||
await self._init_mcp_clients()
|
|
||||||
elif data["type"] == "terminate":
|
|
||||||
if "name" in data:
|
|
||||||
# await self._terminate_mcp_client(data["name"])
|
|
||||||
if data["name"] in self.mcp_client_event:
|
|
||||||
self.mcp_client_event[data["name"]].set()
|
|
||||||
self.mcp_client_event.pop(data["name"], None)
|
|
||||||
self.func_list = [
|
|
||||||
f
|
|
||||||
for f in self.func_list
|
|
||||||
if not (
|
|
||||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
for name in self.mcp_client_dict.keys():
|
|
||||||
# await self._terminate_mcp_client(name)
|
|
||||||
# self.mcp_client_event[name].set()
|
|
||||||
if name in self.mcp_client_event:
|
|
||||||
self.mcp_client_event[name].set()
|
|
||||||
self.mcp_client_event.pop(name, None)
|
|
||||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
|
||||||
|
|
||||||
async def _init_mcp_client_task_wrapper(
|
async def _init_mcp_client_task_wrapper(
|
||||||
self, name: str, cfg: dict, event: asyncio.Event
|
self,
|
||||||
|
name: str,
|
||||||
|
cfg: dict,
|
||||||
|
event: asyncio.Event,
|
||||||
|
ready_future: asyncio.Future = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
try:
|
try:
|
||||||
await self._init_mcp_client(name, cfg)
|
await self._init_mcp_client(name, cfg)
|
||||||
|
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||||
|
if ready_future and not ready_future.done():
|
||||||
|
# tell the caller we are ready
|
||||||
|
ready_future.set_result(tools)
|
||||||
await event.wait()
|
await event.wait()
|
||||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||||
await self._terminate_mcp_client(name)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||||
|
if ready_future and not ready_future.done():
|
||||||
traceback.print_exc()
|
ready_future.set_exception(e)
|
||||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
finally:
|
||||||
|
# 无论如何都能清理
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
|
||||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||||
"""初始化单个MCP客户端"""
|
"""初始化单个MCP客户端"""
|
||||||
try:
|
# 先清理之前的客户端,如果存在
|
||||||
# 先清理之前的客户端,如果存在
|
if name in self.mcp_client_dict:
|
||||||
if name in self.mcp_client_dict:
|
await self._terminate_mcp_client(name)
|
||||||
await self._terminate_mcp_client(name)
|
|
||||||
|
|
||||||
mcp_client = MCPClient()
|
mcp_client = MCPClient()
|
||||||
mcp_client.name = name
|
mcp_client.name = name
|
||||||
self.mcp_client_dict[name] = mcp_client
|
self.mcp_client_dict[name] = mcp_client
|
||||||
await mcp_client.connect_to_server(config, name)
|
await mcp_client.connect_to_server(config, name)
|
||||||
tools_res = await mcp_client.list_tools_and_save()
|
tools_res = await mcp_client.list_tools_and_save()
|
||||||
tool_names = [tool.name for tool in tools_res.tools]
|
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||||
|
tool_names = [tool.name for tool in tools_res.tools]
|
||||||
|
|
||||||
# 移除该MCP服务之前的工具(如有)
|
# 移除该MCP服务之前的工具(如有)
|
||||||
self.func_list = [
|
self.func_list = [
|
||||||
f
|
f
|
||||||
for f in self.func_list
|
for f in self.func_list
|
||||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||||
for tool in mcp_client.tools:
|
for tool in mcp_client.tools:
|
||||||
func_tool = FuncTool(
|
func_tool = FuncTool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
parameters=tool.inputSchema,
|
parameters=tool.inputSchema,
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
origin="mcp",
|
origin="mcp",
|
||||||
mcp_server_name=name,
|
mcp_server_name=name,
|
||||||
mcp_client=mcp_client,
|
mcp_client=mcp_client,
|
||||||
)
|
)
|
||||||
self.func_list.append(func_tool)
|
self.func_list.append(func_tool)
|
||||||
|
|
||||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
|
||||||
# 发生错误时确保客户端被清理
|
|
||||||
if name in self.mcp_client_dict:
|
|
||||||
await self._terminate_mcp_client(name)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _terminate_mcp_client(self, name: str) -> None:
|
async def _terminate_mcp_client(self, name: str) -> None:
|
||||||
"""关闭并清理MCP客户端"""
|
"""关闭并清理MCP客户端"""
|
||||||
@@ -414,9 +444,9 @@ class FuncCall:
|
|||||||
try:
|
try:
|
||||||
# 关闭MCP连接
|
# 关闭MCP连接
|
||||||
await self.mcp_client_dict[name].cleanup()
|
await self.mcp_client_dict[name].cleanup()
|
||||||
del self.mcp_client_dict[name]
|
self.mcp_client_dict.pop(name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||||
# 移除关联的FuncTool
|
# 移除关联的FuncTool
|
||||||
self.func_list = [
|
self.func_list = [
|
||||||
f
|
f
|
||||||
@@ -425,6 +455,103 @@ class FuncCall:
|
|||||||
]
|
]
|
||||||
logger.info(f"已关闭 MCP 服务 {name}")
|
logger.info(f"已关闭 MCP 服务 {name}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||||
|
if "url" in config:
|
||||||
|
success, error_msg = await _quick_test_mcp_connection(config)
|
||||||
|
if not success:
|
||||||
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
mcp_client = MCPClient()
|
||||||
|
try:
|
||||||
|
logger.debug(f"testing MCP server connection with config: {config}")
|
||||||
|
await mcp_client.connect_to_server(config, "test")
|
||||||
|
tools_res = await mcp_client.list_tools_and_save()
|
||||||
|
tool_names = [tool.name for tool in tools_res.tools]
|
||||||
|
finally:
|
||||||
|
logger.debug("Cleaning up MCP client after testing connection.")
|
||||||
|
await mcp_client.cleanup()
|
||||||
|
return tool_names
|
||||||
|
|
||||||
|
async def enable_mcp_server(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: dict,
|
||||||
|
event: asyncio.Event | None = None,
|
||||||
|
ready_future: asyncio.Future | None = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
) -> None:
|
||||||
|
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the MCP server.
|
||||||
|
config (dict): Configuration for the MCP server.
|
||||||
|
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||||
|
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||||
|
timeout (int): Timeout for the initialization.
|
||||||
|
Raises:
|
||||||
|
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||||
|
Exception: If there is an error during initialization.
|
||||||
|
"""
|
||||||
|
if not event:
|
||||||
|
event = asyncio.Event()
|
||||||
|
if not ready_future:
|
||||||
|
ready_future = asyncio.Future()
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
return
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||||
|
finally:
|
||||||
|
self.mcp_client_event[name] = event
|
||||||
|
|
||||||
|
if ready_future.done() and ready_future.exception():
|
||||||
|
exc = ready_future.exception()
|
||||||
|
if exc is not None:
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def disable_mcp_server(
|
||||||
|
self, name: str | None = None, timeout: float = 10
|
||||||
|
) -> None:
|
||||||
|
"""Disable an MCP server by its name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||||
|
timeout (int): Timeout.
|
||||||
|
"""
|
||||||
|
if name:
|
||||||
|
if name not in self.mcp_client_event:
|
||||||
|
return
|
||||||
|
client = self.mcp_client_dict.get(name)
|
||||||
|
self.mcp_client_event[name].set()
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
client_running_event = client.running_event
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||||
|
finally:
|
||||||
|
self.mcp_client_event.pop(name, None)
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if f.origin != "mcp" or f.mcp_server_name != name
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
running_events = [
|
||||||
|
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||||
|
]
|
||||||
|
for key, event in self.mcp_client_event.items():
|
||||||
|
event.set()
|
||||||
|
# waiting for all clients to finish
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||||
|
finally:
|
||||||
|
self.mcp_client_event.clear()
|
||||||
|
self.mcp_client_dict.clear()
|
||||||
|
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||||
|
|
||||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||||
"""
|
"""
|
||||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||||
@@ -629,8 +756,3 @@ class FuncCall:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.func_list)
|
return str(self.func_list)
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
for name in self.mcp_client_dict.keys():
|
|
||||||
await self._terminate_mcp_client(name)
|
|
||||||
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
import traceback
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
import traceback
|
||||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
|
||||||
from .entities import ProviderType
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from .register import provider_cls_map, llm_tools
|
|
||||||
from astrbot.core import logger, sp
|
from astrbot.core import logger, sp
|
||||||
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
|
||||||
|
from .entities import ProviderType
|
||||||
|
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||||
|
from .register import llm_tools, provider_cls_map
|
||||||
|
|
||||||
|
|
||||||
class ProviderManager:
|
class ProviderManager:
|
||||||
@@ -91,17 +93,17 @@ class ProviderManager:
|
|||||||
"""加载的 Speech To Text Provider 的实例"""
|
"""加载的 Speech To Text Provider 的实例"""
|
||||||
self.tts_provider_insts: List[TTSProvider] = []
|
self.tts_provider_insts: List[TTSProvider] = []
|
||||||
"""加载的 Text To Speech Provider 的实例"""
|
"""加载的 Text To Speech Provider 的实例"""
|
||||||
self.embedding_provider_insts: List[Provider] = []
|
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||||
"""加载的 Embedding Provider 的实例"""
|
"""加载的 Embedding Provider 的实例"""
|
||||||
self.inst_map = {}
|
self.inst_map: dict[str, Provider] = {}
|
||||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||||
self.llm_tools = llm_tools
|
self.llm_tools = llm_tools
|
||||||
|
|
||||||
self.curr_provider_inst: Provider = None
|
self.curr_provider_inst: Provider | None = None
|
||||||
"""默认的 Provider 实例"""
|
"""默认的 Provider 实例"""
|
||||||
self.curr_stt_provider_inst: STTProvider = None
|
self.curr_stt_provider_inst: STTProvider | None = None
|
||||||
"""默认的 Speech To Text Provider 实例"""
|
"""默认的 Speech To Text Provider 实例"""
|
||||||
self.curr_tts_provider_inst: TTSProvider = None
|
self.curr_tts_provider_inst: TTSProvider | None = None
|
||||||
"""默认的 Text To Speech Provider 实例"""
|
"""默认的 Text To Speech Provider 实例"""
|
||||||
self.db_helper = db_helper
|
self.db_helper = db_helper
|
||||||
|
|
||||||
@@ -145,29 +147,29 @@ class ProviderManager:
|
|||||||
await self.load_provider(provider_config)
|
await self.load_provider(provider_config)
|
||||||
|
|
||||||
# 设置默认提供商
|
# 设置默认提供商
|
||||||
self.curr_provider_inst = self.inst_map.get(
|
selected_provider_id = sp.get(
|
||||||
self.provider_settings.get("default_provider_id")
|
"curr_provider", self.provider_settings.get("default_provider_id")
|
||||||
)
|
)
|
||||||
|
selected_stt_provider_id = sp.get(
|
||||||
|
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
|
||||||
|
)
|
||||||
|
selected_tts_provider_id = sp.get(
|
||||||
|
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
|
||||||
|
)
|
||||||
|
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||||
if not self.curr_provider_inst and self.provider_insts:
|
if not self.curr_provider_inst and self.provider_insts:
|
||||||
self.curr_provider_inst = self.provider_insts[0]
|
self.curr_provider_inst = self.provider_insts[0]
|
||||||
|
|
||||||
self.curr_stt_provider_inst = self.inst_map.get(
|
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
|
||||||
self.provider_stt_settings.get("provider_id")
|
|
||||||
)
|
|
||||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||||
|
|
||||||
self.curr_tts_provider_inst = self.inst_map.get(
|
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
|
||||||
self.provider_tts_settings.get("provider_id")
|
|
||||||
)
|
|
||||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||||
|
|
||||||
# 初始化 MCP Client 连接
|
# 初始化 MCP Client 连接
|
||||||
asyncio.create_task(
|
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||||
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
|
||||||
)
|
|
||||||
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
|
||||||
|
|
||||||
async def load_provider(self, provider_config: dict):
|
async def load_provider(self, provider_config: dict):
|
||||||
if not provider_config["enable"]:
|
if not provider_config["enable"]:
|
||||||
@@ -190,11 +192,6 @@ class ProviderManager:
|
|||||||
from .sources.anthropic_source import (
|
from .sources.anthropic_source import (
|
||||||
ProviderAnthropic as ProviderAnthropic,
|
ProviderAnthropic as ProviderAnthropic,
|
||||||
)
|
)
|
||||||
case "llm_tuner":
|
|
||||||
logger.info("加载 LLM Tuner 工具 ...")
|
|
||||||
from .sources.llmtuner_source import (
|
|
||||||
LLMTunerModelLoader as LLMTunerModelLoader,
|
|
||||||
)
|
|
||||||
case "dify":
|
case "dify":
|
||||||
from .sources.dify_source import ProviderDify as ProviderDify
|
from .sources.dify_source import ProviderDify as ProviderDify
|
||||||
case "dashscope":
|
case "dashscope":
|
||||||
@@ -253,6 +250,10 @@ class ProviderManager:
|
|||||||
from .sources.volcengine_tts import (
|
from .sources.volcengine_tts import (
|
||||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||||
)
|
)
|
||||||
|
case "gemini_tts":
|
||||||
|
from .sources.gemini_tts_source import (
|
||||||
|
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||||
|
)
|
||||||
case "openai_embedding":
|
case "openai_embedding":
|
||||||
from .sources.openai_embedding_source import (
|
from .sources.openai_embedding_source import (
|
||||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||||
@@ -326,8 +327,6 @@ class ProviderManager:
|
|||||||
inst = provider_metadata.cls_type(
|
inst = provider_metadata.cls_type(
|
||||||
provider_config,
|
provider_config,
|
||||||
self.provider_settings,
|
self.provider_settings,
|
||||||
self.db_helper,
|
|
||||||
self.provider_settings.get("persistant_history", True),
|
|
||||||
self.selected_default_persona,
|
self.selected_default_persona,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -420,7 +419,7 @@ class ProviderManager:
|
|||||||
self.curr_tts_provider_inst = None
|
self.curr_tts_provider_inst = None
|
||||||
|
|
||||||
if getattr(self.inst_map[provider_id], "terminate", None):
|
if getattr(self.inst_map[provider_id], "terminate", None):
|
||||||
await self.inst_map[provider_id].terminate()
|
await self.inst_map[provider_id].terminate() # type: ignore
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
||||||
@@ -430,6 +429,8 @@ class ProviderManager:
|
|||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
for provider_inst in self.provider_insts:
|
for provider_inst in self.provider_insts:
|
||||||
if hasattr(provider_inst, "terminate"):
|
if hasattr(provider_inst, "terminate"):
|
||||||
await provider_inst.terminate()
|
await provider_inst.terminate() # type: ignore
|
||||||
# 清理 MCP Client 连接
|
try:
|
||||||
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
await self.llm_tools.disable_mcp_server()
|
||||||
|
except Exception:
|
||||||
|
logger.error("Error while disabling MCP servers", exc_info=True)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import abc
|
import abc
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from typing import TypedDict, AsyncGenerator
|
from typing import TypedDict, AsyncGenerator
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
|
||||||
|
from astrbot.core.provider.register import provider_cls_map
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ class ProviderMeta:
|
|||||||
id: str
|
id: str
|
||||||
model: str
|
model: str
|
||||||
type: str
|
type: str
|
||||||
|
provider_type: ProviderType
|
||||||
|
|
||||||
|
|
||||||
class AbstractProvider(abc.ABC):
|
class AbstractProvider(abc.ABC):
|
||||||
@@ -41,10 +42,14 @@ class AbstractProvider(abc.ABC):
|
|||||||
|
|
||||||
def meta(self) -> ProviderMeta:
|
def meta(self) -> ProviderMeta:
|
||||||
"""获取 Provider 的元数据"""
|
"""获取 Provider 的元数据"""
|
||||||
|
provider_type_name = self.provider_config["type"]
|
||||||
|
meta_data = provider_cls_map.get(provider_type_name)
|
||||||
|
provider_type = meta_data.provider_type if meta_data else None
|
||||||
return ProviderMeta(
|
return ProviderMeta(
|
||||||
id=self.provider_config["id"],
|
id=self.provider_config["id"],
|
||||||
model=self.get_model(),
|
model=self.get_model(),
|
||||||
type=self.provider_config["type"],
|
type=provider_type_name,
|
||||||
|
provider_type=provider_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -53,15 +58,13 @@ class Provider(AbstractProvider):
|
|||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config: dict,
|
||||||
provider_settings: dict,
|
provider_settings: dict,
|
||||||
persistant_history: bool = True,
|
default_persona: Personality | None = None,
|
||||||
db_helper: BaseDatabase = None,
|
|
||||||
default_persona: Personality = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(provider_config)
|
super().__init__(provider_config)
|
||||||
|
|
||||||
self.provider_settings = provider_settings
|
self.provider_settings = provider_settings
|
||||||
|
|
||||||
self.curr_personality: Personality = default_persona
|
self.curr_personality = default_persona
|
||||||
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -86,11 +89,12 @@ class Provider(AbstractProvider):
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
image_urls: List[str] = None,
|
image_urls: list[str] = None,
|
||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts: List = None,
|
contexts: list = None,
|
||||||
system_prompt: str = None,
|
system_prompt: str = None,
|
||||||
tool_calls_result: ToolCallsResult = None,
|
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||||
|
model: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||||
@@ -114,11 +118,12 @@ class Provider(AbstractProvider):
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
image_urls: List[str] = None,
|
image_urls: list[str] = None,
|
||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts: List = None,
|
contexts: list = None,
|
||||||
system_prompt: str = None,
|
system_prompt: str = None,
|
||||||
tool_calls_result: ToolCallsResult = None,
|
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||||
|
model: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import json
|
||||||
|
import anthropic
|
||||||
|
import base64
|
||||||
from typing import List
|
from typing import List
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
|
|
||||||
@@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic
|
|||||||
from anthropic.types import Message
|
from anthropic.types import Message
|
||||||
|
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.api.provider import Provider, Personality
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
from typing import AsyncGenerator
|
||||||
from .openai_source import ProviderOpenAIOfficial
|
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
@register_provider_adapter(
|
||||||
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
|
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
|
||||||
)
|
)
|
||||||
class ProviderAnthropic(ProviderOpenAIOfficial):
|
class ProviderAnthropic(Provider):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config,
|
||||||
provider_settings: dict,
|
provider_settings,
|
||||||
db_helper: BaseDatabase,
|
default_persona=None,
|
||||||
persistant_history=True,
|
|
||||||
default_persona: Personality = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
# Skip OpenAI's __init__ and call Provider's __init__ directly
|
super().__init__(
|
||||||
Provider.__init__(
|
|
||||||
self,
|
|
||||||
provider_config,
|
provider_config,
|
||||||
provider_settings,
|
provider_settings,
|
||||||
persistant_history,
|
|
||||||
db_helper,
|
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.chosen_api_key = None
|
self.chosen_api_key: str = ""
|
||||||
self.api_keys: List = provider_config.get("key", [])
|
self.api_keys: List = provider_config.get("key", [])
|
||||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||||
self.timeout = provider_config.get("timeout", 120)
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
if isinstance(self.timeout, str):
|
if isinstance(self.timeout, str):
|
||||||
@@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
|
|
||||||
self.set_model(provider_config["model_config"]["model"])
|
self.set_model(provider_config["model_config"]["model"])
|
||||||
|
|
||||||
|
def _prepare_payload(self, messages: list[dict]):
|
||||||
|
"""准备 Anthropic API 的请求 payload
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息
|
||||||
|
Returns:
|
||||||
|
system_prompt: 系统提示内容
|
||||||
|
new_messages: 处理后的消息列表,去除系统提示
|
||||||
|
"""
|
||||||
|
system_prompt = ""
|
||||||
|
new_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
system_prompt = message["content"]
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
blocks = []
|
||||||
|
if isinstance(message["content"], str):
|
||||||
|
blocks.append({"type": "text", "text": message["content"]})
|
||||||
|
if "tool_calls" in message:
|
||||||
|
for tool_call in message["tool_calls"]:
|
||||||
|
blocks.append( # noqa: PERF401
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"name": tool_call["function"]["name"],
|
||||||
|
"input": json.loads(tool_call["function"]["arguments"])
|
||||||
|
if isinstance(tool_call["function"]["arguments"], str)
|
||||||
|
else tool_call["function"]["arguments"],
|
||||||
|
"id": tool_call["id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": blocks,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": message["tool_call_id"],
|
||||||
|
"content": message["content"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_messages.append(message)
|
||||||
|
|
||||||
|
return system_prompt, new_messages
|
||||||
|
|
||||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||||
if tools:
|
if tools:
|
||||||
tool_list = tools.get_func_desc_anthropic_style()
|
if tool_list := tools.get_func_desc_anthropic_style():
|
||||||
if tool_list:
|
|
||||||
payloads["tools"] = tool_list
|
payloads["tools"] = tool_list
|
||||||
|
|
||||||
completion = await self.client.messages.create(**payloads, stream=False)
|
completion = await self.client.messages.create(**payloads, stream=False)
|
||||||
@@ -64,70 +112,158 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
|
|
||||||
if len(completion.content) == 0:
|
if len(completion.content) == 0:
|
||||||
raise Exception("API 返回的 completion 为空。")
|
raise Exception("API 返回的 completion 为空。")
|
||||||
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
|
|
||||||
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
|
|
||||||
content = completion.content[-1]
|
|
||||||
|
|
||||||
llm_response = LLMResponse("assistant")
|
llm_response = LLMResponse(role="assistant")
|
||||||
|
|
||||||
if content.type == "text":
|
for content_block in completion.content:
|
||||||
# text completion
|
if content_block.type == "text":
|
||||||
completion_text = str(content.text).strip()
|
completion_text = str(content_block.text).strip()
|
||||||
# llm_response.completion_text = completion_text
|
llm_response.completion_text = completion_text
|
||||||
llm_response.result_chain = MessageChain().message(completion_text)
|
|
||||||
|
|
||||||
# Anthropic每次只返回一个函数调用
|
|
||||||
if completion.stop_reason == "tool_use":
|
|
||||||
# tools call (function calling)
|
|
||||||
args_ls = []
|
|
||||||
func_name_ls = []
|
|
||||||
tool_use_ids = []
|
|
||||||
func_name_ls.append(content.name)
|
|
||||||
args_ls.append(content.input)
|
|
||||||
tool_use_ids.append(content.id)
|
|
||||||
llm_response.role = "tool"
|
|
||||||
llm_response.tools_call_args = args_ls
|
|
||||||
llm_response.tools_call_name = func_name_ls
|
|
||||||
llm_response.tools_call_ids = tool_use_ids
|
|
||||||
|
|
||||||
|
if content_block.type == "tool_use":
|
||||||
|
llm_response.tools_call_args.append(content_block.input)
|
||||||
|
llm_response.tools_call_name.append(content_block.name)
|
||||||
|
llm_response.tools_call_ids.append(content_block.id)
|
||||||
|
# TODO(Soulter): 处理 end_turn 情况
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
|
||||||
|
|
||||||
llm_response.raw_completion = completion
|
|
||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
async def _query_stream(
|
||||||
|
self, payloads: dict, tools: FuncCall
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
if tools:
|
||||||
|
if tool_list := tools.get_func_desc_anthropic_style():
|
||||||
|
payloads["tools"] = tool_list
|
||||||
|
|
||||||
|
# 用于累积工具调用信息
|
||||||
|
tool_use_buffer = {}
|
||||||
|
# 用于累积最终结果
|
||||||
|
final_text = ""
|
||||||
|
final_tool_calls = []
|
||||||
|
|
||||||
|
async with self.client.messages.stream(**payloads) as stream:
|
||||||
|
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||||
|
async for event in stream:
|
||||||
|
if event.type == "content_block_start":
|
||||||
|
if event.content_block.type == "text":
|
||||||
|
# 文本块开始
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant", completion_text="", is_chunk=True
|
||||||
|
)
|
||||||
|
elif event.content_block.type == "tool_use":
|
||||||
|
# 工具使用块开始,初始化缓冲区
|
||||||
|
tool_use_buffer[event.index] = {
|
||||||
|
"id": event.content_block.id,
|
||||||
|
"name": event.content_block.name,
|
||||||
|
"input": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
elif event.type == "content_block_delta":
|
||||||
|
if event.delta.type == "text_delta":
|
||||||
|
# 文本增量
|
||||||
|
final_text += event.delta.text
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text=event.delta.text,
|
||||||
|
is_chunk=True,
|
||||||
|
)
|
||||||
|
elif event.delta.type == "input_json_delta":
|
||||||
|
# 工具调用参数增量
|
||||||
|
if event.index in tool_use_buffer:
|
||||||
|
# 累积 JSON 输入
|
||||||
|
if "input_json" not in tool_use_buffer[event.index]:
|
||||||
|
tool_use_buffer[event.index]["input_json"] = ""
|
||||||
|
tool_use_buffer[event.index]["input_json"] += (
|
||||||
|
event.delta.partial_json
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event.type == "content_block_stop":
|
||||||
|
# 内容块结束
|
||||||
|
if event.index in tool_use_buffer:
|
||||||
|
# 解析完整的工具调用
|
||||||
|
tool_info = tool_use_buffer[event.index]
|
||||||
|
try:
|
||||||
|
if "input_json" in tool_info:
|
||||||
|
tool_info["input"] = json.loads(tool_info["input_json"])
|
||||||
|
|
||||||
|
# 添加到最终结果
|
||||||
|
final_tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": tool_info["id"],
|
||||||
|
"name": tool_info["name"],
|
||||||
|
"input": tool_info["input"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
yield LLMResponse(
|
||||||
|
role="tool",
|
||||||
|
completion_text="",
|
||||||
|
tools_call_args=[tool_info["input"]],
|
||||||
|
tools_call_name=[tool_info["name"]],
|
||||||
|
tools_call_ids=[tool_info["id"]],
|
||||||
|
is_chunk=True,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# JSON 解析失败,跳过这个工具调用
|
||||||
|
logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}")
|
||||||
|
|
||||||
|
# 清理缓冲区
|
||||||
|
del tool_use_buffer[event.index]
|
||||||
|
|
||||||
|
# 返回最终的完整结果
|
||||||
|
final_response = LLMResponse(
|
||||||
|
role="assistant", completion_text=final_text, is_chunk=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_tool_calls:
|
||||||
|
final_response.tools_call_args = [
|
||||||
|
call["input"] for call in final_tool_calls
|
||||||
|
]
|
||||||
|
final_response.tools_call_name = [call["name"] for call in final_tool_calls]
|
||||||
|
final_response.tools_call_ids = [call["id"] for call in final_tool_calls]
|
||||||
|
|
||||||
|
yield final_response
|
||||||
|
|
||||||
async def text_chat(
|
async def text_chat(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt,
|
||||||
session_id: str = None,
|
session_id=None,
|
||||||
image_urls: List[str] = [],
|
image_urls=None,
|
||||||
func_tool: FuncCall = None,
|
func_tool=None,
|
||||||
contexts=None,
|
contexts=None,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result: ToolCallsResult = None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
contexts = []
|
contexts = []
|
||||||
if not prompt:
|
|
||||||
prompt = "<image>"
|
|
||||||
|
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
context_query = [*contexts, new_record]
|
context_query = [*contexts, new_record]
|
||||||
|
if system_prompt:
|
||||||
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
for part in context_query:
|
for part in context_query:
|
||||||
if "_no_save" in part:
|
if "_no_save" in part:
|
||||||
del part["_no_save"]
|
del part["_no_save"]
|
||||||
|
|
||||||
|
# tool calls result
|
||||||
if tool_calls_result:
|
if tool_calls_result:
|
||||||
# 暂时这样写。
|
if not isinstance(tool_calls_result, list):
|
||||||
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
else:
|
||||||
|
for tcr in tool_calls_result:
|
||||||
|
context_query.extend(tcr.to_openai_messages())
|
||||||
|
|
||||||
|
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
|
model_config["model"] = model or self.get_model()
|
||||||
|
|
||||||
|
payloads = {"messages": new_messages, **model_config}
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
|
||||||
# Anthropic has a different way of handling system prompts
|
# Anthropic has a different way of handling system prompts
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
payloads["system"] = system_prompt
|
payloads["system"] = system_prompt
|
||||||
@@ -135,32 +271,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
llm_response = None
|
llm_response = None
|
||||||
try:
|
try:
|
||||||
llm_response = await self._query(payloads, func_tool)
|
llm_response = await self._query(payloads, func_tool)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "maximum context length" in str(e):
|
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||||
retry_cnt = 20
|
raise e
|
||||||
while retry_cnt > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await self.pop_record(context_query)
|
|
||||||
response = await self.client.messages.create(
|
|
||||||
messages=context_query, **model_config
|
|
||||||
)
|
|
||||||
llm_response = LLMResponse("assistant")
|
|
||||||
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
|
||||||
llm_response.raw_completion = response
|
|
||||||
return llm_response
|
|
||||||
except Exception as e:
|
|
||||||
if "maximum context length" in str(e):
|
|
||||||
retry_cnt -= 1
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
|
||||||
else:
|
|
||||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
@@ -173,23 +286,41 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
contexts=...,
|
contexts=...,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# raise NotImplementedError("This method is not implemented yet.")
|
if contexts is None:
|
||||||
# 调用 text_chat 模拟流式
|
contexts = []
|
||||||
llm_response = await self.text_chat(
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
prompt=prompt,
|
context_query = [*contexts, new_record]
|
||||||
session_id=session_id,
|
if system_prompt:
|
||||||
image_urls=image_urls,
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||||
func_tool=func_tool,
|
|
||||||
contexts=contexts,
|
for part in context_query:
|
||||||
system_prompt=system_prompt,
|
if "_no_save" in part:
|
||||||
tool_calls_result=tool_calls_result,
|
del part["_no_save"]
|
||||||
)
|
|
||||||
llm_response.is_chunk = True
|
# tool calls result
|
||||||
yield llm_response
|
if tool_calls_result:
|
||||||
llm_response.is_chunk = False
|
if not isinstance(tool_calls_result, list):
|
||||||
yield llm_response
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
else:
|
||||||
|
for tcr in tool_calls_result:
|
||||||
|
context_query.extend(tcr.to_openai_messages())
|
||||||
|
|
||||||
|
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||||
|
|
||||||
|
model_config = self.provider_config.get("model_config", {})
|
||||||
|
model_config["model"] = model or self.get_model()
|
||||||
|
|
||||||
|
payloads = {"messages": new_messages, **model_config}
|
||||||
|
|
||||||
|
# Anthropic has a different way of handling system prompts
|
||||||
|
if system_prompt:
|
||||||
|
payloads["system"] = system_prompt
|
||||||
|
|
||||||
|
async for llm_response in self._query_stream(payloads, func_tool):
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||||
"""组装上下文,支持文本和图片"""
|
"""组装上下文,支持文本和图片"""
|
||||||
@@ -232,3 +363,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {"role": "user", "content": content}
|
return {"role": "user", "content": content}
|
||||||
|
|
||||||
|
async def encode_image_bs64(self, image_url: str) -> str:
|
||||||
|
"""
|
||||||
|
将图片转换为 base64
|
||||||
|
"""
|
||||||
|
if image_url.startswith("base64://"):
|
||||||
|
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||||
|
with open(image_url, "rb") as f:
|
||||||
|
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||||
|
return "data:image/jpeg;base64," + image_bs64
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_current_key(self) -> str:
|
||||||
|
return self.chosen_api_key
|
||||||
|
|
||||||
|
async def get_models(self) -> List[str]:
|
||||||
|
models_str = []
|
||||||
|
models = await self.client.models.list()
|
||||||
|
models = sorted(models.data, key=lambda x: x.id)
|
||||||
|
for model in models:
|
||||||
|
models_str.append(model.id)
|
||||||
|
return models_str
|
||||||
|
|
||||||
|
def set_key(self, key: str):
|
||||||
|
self.chosen_api_key = key
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from ..register import register_provider_adapter
|
|||||||
TEMP_DIR = Path("data/temp/azure_tts")
|
TEMP_DIR = Path("data/temp/azure_tts")
|
||||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
class OTTSProvider:
|
class OTTSProvider:
|
||||||
def __init__(self, config: Dict):
|
def __init__(self, config: Dict):
|
||||||
self.skey = config["OTTS_SKEY"]
|
self.skey = config["OTTS_SKEY"]
|
||||||
@@ -70,12 +71,12 @@ class OTTSProvider:
|
|||||||
"style": voice_params["style"],
|
"style": voice_params["style"],
|
||||||
"role": voice_params["role"],
|
"role": voice_params["role"],
|
||||||
"rate": voice_params["rate"],
|
"rate": voice_params["rate"],
|
||||||
"volume": voice_params["volume"]
|
"volume": voice_params["volume"],
|
||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
"User-Agent": f"AstrBot/{VERSION}",
|
"User-Agent": f"AstrBot/{VERSION}",
|
||||||
"UAK": "AstrBot/AzureTTS"
|
"UAK": "AstrBot/AzureTTS",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -88,14 +89,19 @@ class OTTSProvider:
|
|||||||
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
|
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
|
||||||
await asyncio.sleep(0.5 * (attempt + 1))
|
await asyncio.sleep(0.5 * (attempt + 1))
|
||||||
|
|
||||||
|
|
||||||
class AzureNativeProvider(TTSProvider):
|
class AzureNativeProvider(TTSProvider):
|
||||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||||
super().__init__(provider_config, provider_settings)
|
super().__init__(provider_config, provider_settings)
|
||||||
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
|
self.subscription_key = provider_config.get(
|
||||||
|
"azure_tts_subscription_key", ""
|
||||||
|
).strip()
|
||||||
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||||
raise ValueError("无效的Azure订阅密钥")
|
raise ValueError("无效的Azure订阅密钥")
|
||||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||||
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
self.endpoint = (
|
||||||
|
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||||
|
)
|
||||||
self.client = None
|
self.client = None
|
||||||
self.token = None
|
self.token = None
|
||||||
self.token_expire = 0
|
self.token_expire = 0
|
||||||
@@ -104,15 +110,17 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
"style": provider_config.get("azure_tts_style", "cheerful"),
|
"style": provider_config.get("azure_tts_style", "cheerful"),
|
||||||
"role": provider_config.get("azure_tts_role", "Boy"),
|
"role": provider_config.get("azure_tts_role", "Boy"),
|
||||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||||
"volume": provider_config.get("azure_tts_volume", "100")
|
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.client = AsyncClient(headers={
|
self.client = AsyncClient(
|
||||||
"User-Agent": f"AstrBot/{VERSION}",
|
headers={
|
||||||
"Content-Type": "application/ssml+xml",
|
"User-Agent": f"AstrBot/{VERSION}",
|
||||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
|
"Content-Type": "application/ssml+xml",
|
||||||
})
|
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
|
||||||
|
}
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
@@ -120,10 +128,11 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
await self.client.aclose()
|
await self.client.aclose()
|
||||||
|
|
||||||
async def _refresh_token(self):
|
async def _refresh_token(self):
|
||||||
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
token_url = (
|
||||||
|
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
||||||
|
)
|
||||||
response = await self.client.post(
|
response = await self.client.post(
|
||||||
token_url,
|
token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
||||||
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
self.token = response.text
|
self.token = response.text
|
||||||
@@ -150,8 +159,8 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
content=ssml,
|
content=ssml,
|
||||||
headers={
|
headers={
|
||||||
"Authorization": f"Bearer {self.token}",
|
"Authorization": f"Bearer {self.token}",
|
||||||
"User-Agent": f"AstrBot/{VERSION}"
|
"User-Agent": f"AstrBot/{VERSION}",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -160,6 +169,7 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
return str(file_path.resolve())
|
return str(file_path.resolve())
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
|
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
|
||||||
class AzureTTSProvider(TTSProvider):
|
class AzureTTSProvider(TTSProvider):
|
||||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||||
@@ -183,7 +193,7 @@ class AzureTTSProvider(TTSProvider):
|
|||||||
error_msg = (
|
error_msg = (
|
||||||
f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
|
f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
|
||||||
f"错误详情: {e.msg}\n"
|
f"错误详情: {e.msg}\n"
|
||||||
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
|
f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}"
|
||||||
)
|
)
|
||||||
raise ValueError(error_msg) from e
|
raise ValueError(error_msg) from e
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
@@ -202,8 +212,8 @@ class AzureTTSProvider(TTSProvider):
|
|||||||
"style": self.provider_config.get("azure_tts_style"),
|
"style": self.provider_config.get("azure_tts_style"),
|
||||||
"role": self.provider_config.get("azure_tts_role"),
|
"role": self.provider_config.get("azure_tts_role"),
|
||||||
"rate": self.provider_config.get("azure_tts_rate"),
|
"rate": self.provider_config.get("azure_tts_rate"),
|
||||||
"volume": self.provider_config.get("azure_tts_volume")
|
"volume": self.provider_config.get("azure_tts_volume"),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
async with self.provider as provider:
|
async with self.provider as provider:
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import List
|
|||||||
from .. import Provider, Personality
|
from .. import Provider, Personality
|
||||||
from ..entities import LLMResponse
|
from ..entities import LLMResponse
|
||||||
from ..func_tool_manager import FuncCall
|
from ..func_tool_manager import FuncCall
|
||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from .openai_source import ProviderOpenAIOfficial
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
@@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config: dict,
|
||||||
provider_settings: dict,
|
provider_settings: dict,
|
||||||
db_helper: BaseDatabase,
|
default_persona: Personality | None = None,
|
||||||
persistant_history=False,
|
|
||||||
default_persona: Personality = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
Provider.__init__(
|
Provider.__init__(
|
||||||
self,
|
self,
|
||||||
provider_config,
|
provider_config,
|
||||||
provider_settings,
|
provider_settings,
|
||||||
persistant_history,
|
|
||||||
db_helper,
|
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||||
@@ -72,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts: List = None,
|
contexts: List = None,
|
||||||
system_prompt: str = None,
|
system_prompt: str = None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
@@ -168,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
contexts=...,
|
contexts=...,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# raise NotImplementedError("This method is not implemented yet.")
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
from .. import Provider, Personality
|
from .. import Provider
|
||||||
from ..entities import LLMResponse
|
from ..entities import LLMResponse
|
||||||
from ..func_tool_manager import FuncCall
|
from ..func_tool_manager import FuncCall
|
||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||||
@@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
class ProviderDify(Provider):
|
class ProviderDify(Provider):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config,
|
||||||
provider_settings: dict,
|
provider_settings,
|
||||||
db_helper: BaseDatabase,
|
default_persona=None,
|
||||||
persistant_history=False,
|
|
||||||
default_persona: Personality = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
provider_config,
|
provider_config,
|
||||||
provider_settings,
|
provider_settings,
|
||||||
persistant_history,
|
|
||||||
db_helper,
|
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.api_key = provider_config.get("dify_api_key", "")
|
self.api_key = provider_config.get("dify_api_key", "")
|
||||||
@@ -65,12 +60,14 @@ class ProviderDify(Provider):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts: List = None,
|
contexts: List = None,
|
||||||
system_prompt: str = None,
|
system_prompt: str = None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if image_urls is None:
|
if image_urls is None:
|
||||||
image_urls = []
|
image_urls = []
|
||||||
result = ""
|
result = ""
|
||||||
session_id = session_id or kwargs.get("user") # 1734
|
session_id = session_id or kwargs.get("user") or "unknown" # 1734
|
||||||
conversation_id = self.conversation_ids.get(session_id, "")
|
conversation_id = self.conversation_ids.get(session_id, "")
|
||||||
|
|
||||||
files_payload = []
|
files_payload = []
|
||||||
@@ -103,6 +100,7 @@ class ProviderDify(Provider):
|
|||||||
session_vars = sp.get("session_variables", {})
|
session_vars = sp.get("session_variables", {})
|
||||||
session_var = session_vars.get(session_id, {})
|
session_var = session_vars.get(session_id, {})
|
||||||
payload_vars.update(session_var)
|
payload_vars.update(session_var)
|
||||||
|
payload_vars["system_prompt"] = system_prompt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
match self.api_type:
|
match self.api_type:
|
||||||
@@ -202,6 +200,7 @@ class ProviderDify(Provider):
|
|||||||
contexts=...,
|
contexts=...,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# raise NotImplementedError("This method is not implemented yet.")
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
|||||||
@@ -12,10 +12,9 @@ from google.genai.errors import APIError
|
|||||||
|
|
||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Personality, Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
@@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config,
|
||||||
provider_settings: dict,
|
provider_settings,
|
||||||
db_helper: BaseDatabase,
|
default_persona=None,
|
||||||
persistant_history=True,
|
|
||||||
default_persona: Personality = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
provider_config,
|
provider_config,
|
||||||
provider_settings,
|
provider_settings,
|
||||||
persistant_history,
|
|
||||||
db_helper,
|
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.api_keys: list = provider_config.get("key", [])
|
self.api_keys: list = provider_config.get("key", [])
|
||||||
@@ -475,6 +470,10 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
raise
|
raise
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Accumulate the complete response text for the final response
|
||||||
|
accumulated_text = ""
|
||||||
|
final_response = None
|
||||||
|
|
||||||
async for chunk in result:
|
async for chunk in result:
|
||||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||||
|
|
||||||
@@ -486,32 +485,47 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
chunk, llm_response
|
chunk, llm_response
|
||||||
)
|
)
|
||||||
yield llm_response
|
yield llm_response
|
||||||
break
|
return
|
||||||
|
|
||||||
if chunk.text:
|
if chunk.text:
|
||||||
|
accumulated_text += chunk.text
|
||||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||||
yield llm_response
|
yield llm_response
|
||||||
|
|
||||||
if chunk.candidates[0].finish_reason:
|
if chunk.candidates[0].finish_reason:
|
||||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
# Process the final chunk for potential tool calls or other content
|
||||||
if not chunk.candidates[0].content.parts:
|
if chunk.candidates[0].content.parts:
|
||||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
final_response = LLMResponse("assistant", is_chunk=False)
|
||||||
else:
|
final_response.result_chain = self._process_content_parts(
|
||||||
llm_response.result_chain = self._process_content_parts(
|
chunk, final_response
|
||||||
chunk, llm_response
|
|
||||||
)
|
)
|
||||||
yield llm_response
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Yield final complete response with accumulated text
|
||||||
|
if not final_response:
|
||||||
|
final_response = LLMResponse("assistant", is_chunk=False)
|
||||||
|
|
||||||
|
# Set the complete accumulated text in the final response
|
||||||
|
if accumulated_text:
|
||||||
|
final_response.result_chain = MessageChain(
|
||||||
|
chain=[Comp.Plain(accumulated_text)]
|
||||||
|
)
|
||||||
|
elif not final_response.result_chain:
|
||||||
|
# If no text was accumulated and no final response was set, provide empty space
|
||||||
|
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||||
|
|
||||||
|
yield final_response
|
||||||
|
|
||||||
async def text_chat(
|
async def text_chat(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str = None,
|
session_id=None,
|
||||||
image_urls: list[str] = None,
|
image_urls=None,
|
||||||
func_tool: FuncCall = None,
|
func_tool=None,
|
||||||
contexts: list = None,
|
contexts=None,
|
||||||
system_prompt: str = None,
|
system_prompt=None,
|
||||||
tool_calls_result: ToolCallsResult = None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
@@ -527,10 +541,14 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
# tool calls result
|
# tool calls result
|
||||||
if tool_calls_result:
|
if tool_calls_result:
|
||||||
context_query.extend(tool_calls_result.to_openai_messages())
|
if not isinstance(tool_calls_result, list):
|
||||||
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
else:
|
||||||
|
for tcr in tool_calls_result:
|
||||||
|
context_query.extend(tcr.to_openai_messages())
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
model_config["model"] = self.get_model()
|
model_config["model"] = model or self.get_model()
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
payloads = {"messages": context_query, **model_config}
|
||||||
|
|
||||||
@@ -547,13 +565,14 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
async def text_chat_stream(
|
async def text_chat_stream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt,
|
||||||
session_id: str = None,
|
session_id=None,
|
||||||
image_urls: list[str] = None,
|
image_urls=None,
|
||||||
func_tool: FuncCall = None,
|
func_tool=None,
|
||||||
contexts: str = None,
|
contexts=None,
|
||||||
system_prompt: str = None,
|
system_prompt=None,
|
||||||
tool_calls_result: ToolCallsResult = None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
@@ -569,10 +588,14 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
# tool calls result
|
# tool calls result
|
||||||
if tool_calls_result:
|
if tool_calls_result:
|
||||||
context_query.extend(tool_calls_result.to_openai_messages())
|
if not isinstance(tool_calls_result, list):
|
||||||
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
else:
|
||||||
|
for tcr in tool_calls_result:
|
||||||
|
context_query.extend(tcr.to_openai_messages())
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
model_config["model"] = self.get_model()
|
model_config["model"] = model or self.get_model()
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
payloads = {"messages": context_query, **model_config}
|
||||||
|
|
||||||
@@ -632,7 +655,10 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
continue
|
continue
|
||||||
user_content["content"].append(
|
user_content["content"].append(
|
||||||
{"type": "image_url", "image_url": {"url": image_data}}
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": image_data},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return user_content
|
return user_content
|
||||||
else:
|
else:
|
||||||
|
|||||||
79
astrbot/core/provider/sources/gemini_tts_source.py
Normal file
79
astrbot/core/provider/sources/gemini_tts_source.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
from ..entities import ProviderType
|
||||||
|
from ..provider import TTSProvider
|
||||||
|
from ..register import register_provider_adapter
|
||||||
|
|
||||||
|
|
||||||
|
@register_provider_adapter(
|
||||||
|
"gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||||
|
)
|
||||||
|
class ProviderGeminiTTSAPI(TTSProvider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_config: dict,
|
||||||
|
provider_settings: dict,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(provider_config, provider_settings)
|
||||||
|
api_key: str = provider_config.get("gemini_tts_api_key", "")
|
||||||
|
api_base: str | None = provider_config.get("gemini_tts_api_base")
|
||||||
|
timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
|
||||||
|
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||||
|
|
||||||
|
if api_base:
|
||||||
|
if api_base.endswith("/"):
|
||||||
|
api_base = api_base[:-1]
|
||||||
|
http_options.base_url = api_base
|
||||||
|
|
||||||
|
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||||
|
self.model: str = provider_config.get(
|
||||||
|
"gemini_tts_model", "gemini-2.5-flash-preview-tts"
|
||||||
|
)
|
||||||
|
self.prefix: str | None = provider_config.get(
|
||||||
|
"gemini_tts_prefix",
|
||||||
|
)
|
||||||
|
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
|
||||||
|
|
||||||
|
async def get_audio(self, text: str) -> str:
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
|
||||||
|
prompt = f"{self.prefix}: {text}" if self.prefix else text
|
||||||
|
response = await self.client.models.generate_content(
|
||||||
|
model=self.model,
|
||||||
|
contents=prompt,
|
||||||
|
config=types.GenerateContentConfig(
|
||||||
|
response_modalities=["AUDIO"],
|
||||||
|
speech_config=types.SpeechConfig(
|
||||||
|
voice_config=types.VoiceConfig(
|
||||||
|
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||||
|
voice_name=self.voice_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 不想看类型检查报错
|
||||||
|
if (
|
||||||
|
not response.candidates
|
||||||
|
or not response.candidates[0].content
|
||||||
|
or not response.candidates[0].content.parts
|
||||||
|
or not response.candidates[0].content.parts[0].inline_data
|
||||||
|
or not response.candidates[0].content.parts[0].inline_data.data
|
||||||
|
):
|
||||||
|
raise Exception("No audio content returned from Gemini TTS API.")
|
||||||
|
|
||||||
|
with wave.open(path, "wb") as wf:
|
||||||
|
wf.setnchannels(1)
|
||||||
|
wf.setsampwidth(2)
|
||||||
|
wf.setframerate(24000)
|
||||||
|
wf.writeframes(response.candidates[0].content.parts[0].inline_data.data)
|
||||||
|
|
||||||
|
return path
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
import os
|
|
||||||
from llmtuner.chat import ChatModel
|
|
||||||
from typing import List
|
|
||||||
from .. import Provider
|
|
||||||
from ..entities import LLMResponse
|
|
||||||
from ..func_tool_manager import FuncCall
|
|
||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from ..register import register_provider_adapter
|
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
|
||||||
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
|
|
||||||
)
|
|
||||||
class LLMTunerModelLoader(Provider):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
provider_config: dict,
|
|
||||||
provider_settings: dict,
|
|
||||||
db_helper: BaseDatabase,
|
|
||||||
persistant_history=True,
|
|
||||||
default_persona=None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
provider_config,
|
|
||||||
provider_settings,
|
|
||||||
persistant_history,
|
|
||||||
db_helper,
|
|
||||||
default_persona,
|
|
||||||
)
|
|
||||||
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
|
|
||||||
provider_config["adapter_model_path"]
|
|
||||||
):
|
|
||||||
raise FileNotFoundError("模型文件路径不存在。")
|
|
||||||
self.base_model_path = provider_config["base_model_path"]
|
|
||||||
self.adapter_model_path = provider_config["adapter_model_path"]
|
|
||||||
self.model = ChatModel(
|
|
||||||
{
|
|
||||||
"model_name_or_path": self.base_model_path,
|
|
||||||
"adapter_name_or_path": self.adapter_model_path,
|
|
||||||
"template": provider_config["llmtuner_template"],
|
|
||||||
"finetuning_type": provider_config["finetuning_type"],
|
|
||||||
"quantization_bit": provider_config["quantization_bit"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.set_model(
|
|
||||||
os.path.basename(self.base_model_path)
|
|
||||||
+ "_"
|
|
||||||
+ os.path.basename(self.adapter_model_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
|
||||||
"""
|
|
||||||
组装上下文。
|
|
||||||
"""
|
|
||||||
return {"role": "user", "content": text}
|
|
||||||
|
|
||||||
async def text_chat(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
session_id: str = None,
|
|
||||||
image_urls: List[str] = None,
|
|
||||||
func_tool: FuncCall = None,
|
|
||||||
contexts: List = None,
|
|
||||||
system_prompt: str = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> LLMResponse:
|
|
||||||
if contexts is None:
|
|
||||||
contexts = []
|
|
||||||
system_prompt = ""
|
|
||||||
new_record = {"role": "user", "content": prompt}
|
|
||||||
query_context = [*contexts, new_record]
|
|
||||||
|
|
||||||
# 提取出系统提示
|
|
||||||
system_idxs = []
|
|
||||||
for idx, context in enumerate(query_context):
|
|
||||||
if context["role"] == "system":
|
|
||||||
system_idxs.append(idx)
|
|
||||||
|
|
||||||
if "_no_save" in context:
|
|
||||||
del context["_no_save"]
|
|
||||||
|
|
||||||
for idx in reversed(system_idxs):
|
|
||||||
system_prompt += " " + query_context.pop(idx)["content"]
|
|
||||||
|
|
||||||
conf = {
|
|
||||||
"messages": query_context,
|
|
||||||
"system": system_prompt,
|
|
||||||
}
|
|
||||||
if func_tool:
|
|
||||||
tool_list = func_tool.get_func_desc_openai_style()
|
|
||||||
if tool_list:
|
|
||||||
conf["tools"] = tool_list
|
|
||||||
|
|
||||||
responses = await self.model.achat(**conf)
|
|
||||||
|
|
||||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
|
||||||
|
|
||||||
return llm_response
|
|
||||||
|
|
||||||
async def text_chat_stream(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
session_id=None,
|
|
||||||
image_urls=...,
|
|
||||||
func_tool=None,
|
|
||||||
contexts=...,
|
|
||||||
system_prompt=None,
|
|
||||||
tool_calls_result=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# raise NotImplementedError("This method is not implemented yet.")
|
|
||||||
# 调用 text_chat 模拟流式
|
|
||||||
llm_response = await self.text_chat(
|
|
||||||
prompt=prompt,
|
|
||||||
session_id=session_id,
|
|
||||||
image_urls=image_urls,
|
|
||||||
func_tool=func_tool,
|
|
||||||
contexts=contexts,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tool_calls_result=tool_calls_result,
|
|
||||||
)
|
|
||||||
llm_response.is_chunk = True
|
|
||||||
yield llm_response
|
|
||||||
llm_response.is_chunk = False
|
|
||||||
yield llm_response
|
|
||||||
|
|
||||||
async def get_current_key(self):
|
|
||||||
return "none"
|
|
||||||
|
|
||||||
async def set_key(self, key):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def get_models(self):
|
|
||||||
return [self.get_model()]
|
|
||||||
@@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|||||||
timeout=int(provider_config.get("timeout", 20)),
|
timeout=int(provider_config.get("timeout", 20)),
|
||||||
)
|
)
|
||||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
||||||
|
|
||||||
async def get_embedding(self, text: str) -> list[float]:
|
async def get_embedding(self, text: str) -> list[float]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp
|
|||||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
||||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.api.provider import Provider, Personality
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from typing import List, AsyncGenerator
|
from typing import List, AsyncGenerator
|
||||||
@@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
|||||||
class ProviderOpenAIOfficial(Provider):
|
class ProviderOpenAIOfficial(Provider):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config,
|
||||||
provider_settings: dict,
|
provider_settings,
|
||||||
db_helper: BaseDatabase,
|
default_persona=None,
|
||||||
persistant_history=True,
|
|
||||||
default_persona: Personality = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
provider_config,
|
provider_config,
|
||||||
provider_settings,
|
provider_settings,
|
||||||
persistant_history,
|
|
||||||
db_helper,
|
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.chosen_api_key = None
|
self.chosen_api_key = None
|
||||||
@@ -105,6 +99,11 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
for key in to_del:
|
for key in to_del:
|
||||||
del payloads[key]
|
del payloads[key]
|
||||||
|
|
||||||
|
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
|
||||||
|
model = payloads.get("model", "")
|
||||||
|
if "qwen3" in model.lower():
|
||||||
|
extra_body["enable_thinking"] = False
|
||||||
|
|
||||||
completion = await self.client.chat.completions.create(
|
completion = await self.client.chat.completions.create(
|
||||||
**payloads, stream=False, extra_body=extra_body
|
**payloads, stream=False, extra_body=extra_body
|
||||||
)
|
)
|
||||||
@@ -182,7 +181,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
raise Exception("API 返回的 completion 为空。")
|
raise Exception("API 返回的 completion 为空。")
|
||||||
choice = completion.choices[0]
|
choice = completion.choices[0]
|
||||||
|
|
||||||
if choice.message.content:
|
if choice.message.content is not None:
|
||||||
# text completion
|
# text completion
|
||||||
completion_text = str(choice.message.content).strip()
|
completion_text = str(choice.message.content).strip()
|
||||||
llm_response.result_chain = MessageChain().message(completion_text)
|
llm_response.result_chain = MessageChain().message(completion_text)
|
||||||
@@ -193,6 +192,9 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
func_name_ls = []
|
func_name_ls = []
|
||||||
tool_call_ids = []
|
tool_call_ids = []
|
||||||
for tool_call in choice.message.tool_calls:
|
for tool_call in choice.message.tool_calls:
|
||||||
|
if isinstance(tool_call, str):
|
||||||
|
# workaround for #1359
|
||||||
|
tool_call = json.loads(tool_call)
|
||||||
for tool in tools.func_list:
|
for tool in tools.func_list:
|
||||||
if tool.name == tool_call.function.name:
|
if tool.name == tool_call.function.name:
|
||||||
# workaround for #1454
|
# workaround for #1454
|
||||||
@@ -213,7 +215,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
|
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if llm_response.completion_text is None and not llm_response.tools_call_args:
|
||||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
|
|
||||||
@@ -224,12 +226,11 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
async def _prepare_chat_payload(
|
async def _prepare_chat_payload(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str = None,
|
image_urls: list[str] | None = None,
|
||||||
image_urls: list[str] = None,
|
contexts: list | None = None,
|
||||||
func_tool: FuncCall = None,
|
system_prompt: str | None = None,
|
||||||
contexts: list = None,
|
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||||
system_prompt: str = None,
|
model: str | None = None,
|
||||||
tool_calls_result: ToolCallsResult = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""准备聊天所需的有效载荷和上下文"""
|
"""准备聊天所需的有效载荷和上下文"""
|
||||||
@@ -246,14 +247,18 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
|
|
||||||
# tool calls result
|
# tool calls result
|
||||||
if tool_calls_result:
|
if tool_calls_result:
|
||||||
context_query.extend(tool_calls_result.to_openai_messages())
|
if isinstance(tool_calls_result, ToolCallsResult):
|
||||||
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
else:
|
||||||
|
for tcr in tool_calls_result:
|
||||||
|
context_query.extend(tcr.to_openai_messages())
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
model_config["model"] = self.get_model()
|
model_config["model"] = model or self.get_model()
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
payloads = {"messages": context_query, **model_config}
|
||||||
|
|
||||||
return payloads, context_query, func_tool
|
return payloads, context_query
|
||||||
|
|
||||||
async def _handle_api_error(
|
async def _handle_api_error(
|
||||||
self,
|
self,
|
||||||
@@ -350,16 +355,16 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
contexts=None,
|
contexts=None,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
payloads, context_query = await self._prepare_chat_payload(
|
||||||
prompt,
|
prompt,
|
||||||
session_id,
|
|
||||||
image_urls,
|
image_urls,
|
||||||
func_tool,
|
|
||||||
contexts,
|
contexts,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
tool_calls_result,
|
tool_calls_result,
|
||||||
|
model=model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -419,17 +424,17 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
contexts=[],
|
contexts=[],
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
tool_calls_result=None,
|
tool_calls_result=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
"""流式对话,与服务商交互并逐步返回结果"""
|
"""流式对话,与服务商交互并逐步返回结果"""
|
||||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
payloads, context_query = await self._prepare_chat_payload(
|
||||||
prompt,
|
prompt,
|
||||||
session_id,
|
|
||||||
image_urls,
|
image_urls,
|
||||||
func_tool,
|
|
||||||
contexts,
|
contexts,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
tool_calls_result,
|
tool_calls_result,
|
||||||
|
model=model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -485,13 +490,8 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
"""
|
"""
|
||||||
new_contexts = []
|
new_contexts = []
|
||||||
|
|
||||||
flag = False
|
|
||||||
for context in contexts:
|
for context in contexts:
|
||||||
if flag:
|
if "content" in context and isinstance(context["content"], list):
|
||||||
flag = False # 删除 image 后,下一条(LLM 响应)也要删除
|
|
||||||
continue
|
|
||||||
if isinstance(context["content"], list):
|
|
||||||
flag = True
|
|
||||||
# continue
|
# continue
|
||||||
new_content = []
|
new_content = []
|
||||||
for item in context["content"]:
|
for item in context["content"]:
|
||||||
@@ -534,7 +534,10 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
continue
|
continue
|
||||||
user_content["content"].append(
|
user_content["content"].append(
|
||||||
{"type": "image_url", "image_url": {"url": image_data}}
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": image_data},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return user_content
|
return user_content
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
|
||||||
from ..provider import TTSProvider
|
from ..provider import TTSProvider
|
||||||
from ..entities import ProviderType
|
from ..entities import ProviderType
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
@register_provider_adapter(
|
||||||
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||||
)
|
)
|
||||||
@@ -22,7 +22,9 @@ class ProviderVolcengineTTS(TTSProvider):
|
|||||||
self.cluster = provider_config.get("volcengine_cluster", "")
|
self.cluster = provider_config.get("volcengine_cluster", "")
|
||||||
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
||||||
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
||||||
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
|
self.api_base = provider_config.get(
|
||||||
|
"api_base", "https://openspeech.bytedance.com/api/v1/tts"
|
||||||
|
)
|
||||||
self.timeout = provider_config.get("timeout", 20)
|
self.timeout = provider_config.get("timeout", 20)
|
||||||
|
|
||||||
def _build_request_payload(self, text: str) -> dict:
|
def _build_request_payload(self, text: str) -> dict:
|
||||||
@@ -30,11 +32,9 @@ class ProviderVolcengineTTS(TTSProvider):
|
|||||||
"app": {
|
"app": {
|
||||||
"appid": self.appid,
|
"appid": self.appid,
|
||||||
"token": self.api_key,
|
"token": self.api_key,
|
||||||
"cluster": self.cluster
|
"cluster": self.cluster,
|
||||||
},
|
|
||||||
"user": {
|
|
||||||
"uid": str(uuid.uuid4())
|
|
||||||
},
|
},
|
||||||
|
"user": {"uid": str(uuid.uuid4())},
|
||||||
"audio": {
|
"audio": {
|
||||||
"voice_type": self.voice_type,
|
"voice_type": self.voice_type,
|
||||||
"encoding": "mp3",
|
"encoding": "mp3",
|
||||||
@@ -48,15 +48,15 @@ class ProviderVolcengineTTS(TTSProvider):
|
|||||||
"text_type": "plain",
|
"text_type": "plain",
|
||||||
"operation": "query",
|
"operation": "query",
|
||||||
"with_frontend": 1,
|
"with_frontend": 1,
|
||||||
"frontend_type": "unitTson"
|
"frontend_type": "unitTson",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_audio(self, text: str) -> str:
|
async def get_audio(self, text: str) -> str:
|
||||||
"""异步方法获取语音文件路径"""
|
"""异步方法获取语音文件路径"""
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer; {self.api_key}"
|
"Authorization": f"Bearer; {self.api_key}",
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = self._build_request_payload(text)
|
payload = self._build_request_payload(text)
|
||||||
@@ -71,7 +71,7 @@ class ProviderVolcengineTTS(TTSProvider):
|
|||||||
self.api_base,
|
self.api_base,
|
||||||
data=json.dumps(payload),
|
data=json.dumps(payload),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=self.timeout
|
timeout=self.timeout,
|
||||||
) as response:
|
) as response:
|
||||||
logger.debug(f"响应状态码: {response.status}")
|
logger.debug(f"响应状态码: {response.status}")
|
||||||
|
|
||||||
@@ -90,8 +90,7 @@ class ProviderVolcengineTTS(TTSProvider):
|
|||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None,
|
None, lambda: open(file_path, "wb").write(audio_data)
|
||||||
lambda: open(file_path, "wb").write(audio_data)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
@@ -99,7 +98,9 @@ class ProviderVolcengineTTS(TTSProvider):
|
|||||||
error_msg = resp_data.get("message", "未知错误")
|
error_msg = resp_data.get("message", "未知错误")
|
||||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||||
else:
|
else:
|
||||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
raise Exception(
|
||||||
|
f"火山引擎 TTS API 请求失败: {response.status}, {response_text}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_details = traceback.format_exc()
|
error_details = traceback.format_exc()
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from astrbot.core.db import BaseDatabase
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
|||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config: dict,
|
||||||
provider_settings: dict,
|
provider_settings: dict,
|
||||||
db_helper: BaseDatabase,
|
|
||||||
persistant_history=True,
|
|
||||||
default_persona=None,
|
default_persona=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
provider_config,
|
provider_config,
|
||||||
provider_settings,
|
provider_settings,
|
||||||
db_helper,
|
|
||||||
persistant_history,
|
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts=None,
|
contexts=None,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
|
model=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if contexts is None:
|
if contexts is None:
|
||||||
@@ -43,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
|||||||
context_query = [*contexts, new_record]
|
context_query = [*contexts, new_record]
|
||||||
|
|
||||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||||
model = self.get_model()
|
model = model or self.get_model()
|
||||||
# glm-4v-flash 只支持一张图片
|
# glm-4v-flash 只支持一张图片
|
||||||
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
|
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
|
||||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .star import StarMetadata
|
from .star import StarMetadata, star_map, star_registry
|
||||||
from .star_manager import PluginManager
|
from .star_manager import PluginManager
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
@@ -10,23 +10,48 @@ from astrbot.core.star.star_tools import StarTools
|
|||||||
class Star(CommandParserMixin):
|
class Star(CommandParserMixin):
|
||||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||||
|
|
||||||
def __init__(self, context: Context):
|
def __init__(self, context: Context, config: dict | None = None):
|
||||||
StarTools.initialize(context)
|
StarTools.initialize(context)
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if not star_map.get(cls.__module__):
|
||||||
|
metadata = StarMetadata(
|
||||||
|
star_cls_type=cls,
|
||||||
|
module_path=cls.__module__,
|
||||||
|
)
|
||||||
|
star_map[cls.__module__] = metadata
|
||||||
|
star_registry.append(metadata)
|
||||||
|
else:
|
||||||
|
star_map[cls.__module__].star_cls_type = cls
|
||||||
|
star_map[cls.__module__].module_path = cls.__module__
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def text_to_image(text: str, return_url=True) -> str:
|
||||||
"""将文本转换为图片"""
|
"""将文本转换为图片"""
|
||||||
return await html_renderer.render_t2i(text, return_url=return_url)
|
return await html_renderer.render_t2i(text, return_url=return_url)
|
||||||
|
|
||||||
async def html_render(self, tmpl: str, data: dict, return_url=True) -> str:
|
@staticmethod
|
||||||
|
async def html_render(
|
||||||
|
tmpl: str, data: dict, return_url=True, options: dict = None
|
||||||
|
) -> str:
|
||||||
"""渲染 HTML"""
|
"""渲染 HTML"""
|
||||||
return await html_renderer.render_custom_template(
|
return await html_renderer.render_custom_template(
|
||||||
tmpl, data, return_url=return_url
|
tmpl, data, return_url=return_url, options=options
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""当插件被激活时会调用这个方法"""
|
||||||
|
pass
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||||
|
|||||||
@@ -2,7 +2,12 @@ from asyncio import Queue
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from astrbot.core import sp
|
from astrbot.core import sp
|
||||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
from astrbot.core.provider.provider import (
|
||||||
|
Provider,
|
||||||
|
TTSProvider,
|
||||||
|
STTProvider,
|
||||||
|
EmbeddingProvider,
|
||||||
|
)
|
||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
@@ -141,6 +146,10 @@ class Context:
|
|||||||
"""获取所有用于 STT 任务的 Provider。"""
|
"""获取所有用于 STT 任务的 Provider。"""
|
||||||
return self.provider_manager.stt_provider_insts
|
return self.provider_manager.stt_provider_insts
|
||||||
|
|
||||||
|
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
|
||||||
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||||
|
return self.provider_manager.embedding_provider_insts
|
||||||
|
|
||||||
def get_using_provider(self, umo: str = None) -> Provider:
|
def get_using_provider(self, umo: str = None) -> Provider:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||||
|
|||||||
@@ -7,10 +7,13 @@ from astrbot.core.config import AstrBotConfig
|
|||||||
from .custom_filter import CustomFilter
|
from .custom_filter import CustomFilter
|
||||||
from ..star_handler import StarHandlerMetadata
|
from ..star_handler import StarHandlerMetadata
|
||||||
|
|
||||||
|
|
||||||
class GreedyStr(str):
|
class GreedyStr(str):
|
||||||
"""标记指令完成其他参数接收后的所有剩余文本。"""
|
"""标记指令完成其他参数接收后的所有剩余文本。"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# 标准指令受到 wake_prefix 的制约。
|
# 标准指令受到 wake_prefix 的制约。
|
||||||
class CommandFilter(HandlerFilter):
|
class CommandFilter(HandlerFilter):
|
||||||
"""标准指令过滤器"""
|
"""标准指令过滤器"""
|
||||||
@@ -18,8 +21,8 @@ class CommandFilter(HandlerFilter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
command_name: str,
|
command_name: str,
|
||||||
alias: set = None,
|
alias: set | None = None,
|
||||||
handler_md: StarHandlerMetadata = None,
|
handler_md: StarHandlerMetadata | None = None,
|
||||||
parent_command_names: List[str] = [""],
|
parent_command_names: List[str] = [""],
|
||||||
):
|
):
|
||||||
self.command_name = command_name
|
self.command_name = command_name
|
||||||
@@ -110,6 +113,17 @@ class CommandFilter(HandlerFilter):
|
|||||||
elif isinstance(param_type_or_default_val, str):
|
elif isinstance(param_type_or_default_val, str):
|
||||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||||
result[param_name] = params[i]
|
result[param_name] = params[i]
|
||||||
|
elif isinstance(param_type_or_default_val, bool):
|
||||||
|
# 处理布尔类型
|
||||||
|
lower_param = str(params[i]).lower()
|
||||||
|
if lower_param in ["true", "yes", "1"]:
|
||||||
|
result[param_name] = True
|
||||||
|
elif lower_param in ["false", "no", "0"]:
|
||||||
|
result[param_name] = False
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"参数 {param_name} 必须是布尔值(true/false, yes/no, 1/0)。"
|
||||||
|
)
|
||||||
elif isinstance(param_type_or_default_val, int):
|
elif isinstance(param_type_or_default_val, int):
|
||||||
result[param_name] = int(params[i])
|
result[param_name] = int(params[i])
|
||||||
elif isinstance(param_type_or_default_val, float):
|
elif isinstance(param_type_or_default_val, float):
|
||||||
|
|||||||
@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||||
+ tree
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# complete_command_names = [name + " " for name in complete_command_names]
|
# complete_command_names = [name + " " for name in complete_command_names]
|
||||||
|
|||||||
@@ -8,22 +8,45 @@ from typing import Union
|
|||||||
class PlatformAdapterType(enum.Flag):
|
class PlatformAdapterType(enum.Flag):
|
||||||
AIOCQHTTP = enum.auto()
|
AIOCQHTTP = enum.auto()
|
||||||
QQOFFICIAL = enum.auto()
|
QQOFFICIAL = enum.auto()
|
||||||
VCHAT = enum.auto()
|
|
||||||
GEWECHAT = enum.auto()
|
|
||||||
TELEGRAM = enum.auto()
|
TELEGRAM = enum.auto()
|
||||||
WECOM = enum.auto()
|
WECOM = enum.auto()
|
||||||
LARK = enum.auto()
|
LARK = enum.auto()
|
||||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT | TELEGRAM | WECOM | LARK
|
WECHATPADPRO = enum.auto()
|
||||||
|
DINGTALK = enum.auto()
|
||||||
|
DISCORD = enum.auto()
|
||||||
|
SLACK = enum.auto()
|
||||||
|
KOOK = enum.auto()
|
||||||
|
VOCECHAT = enum.auto()
|
||||||
|
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||||
|
ALL = (
|
||||||
|
AIOCQHTTP
|
||||||
|
| QQOFFICIAL
|
||||||
|
| TELEGRAM
|
||||||
|
| WECOM
|
||||||
|
| LARK
|
||||||
|
| WECHATPADPRO
|
||||||
|
| DINGTALK
|
||||||
|
| DISCORD
|
||||||
|
| SLACK
|
||||||
|
| KOOK
|
||||||
|
| VOCECHAT
|
||||||
|
| WEIXIN_OFFICIAL_ACCOUNT
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ADAPTER_NAME_2_TYPE = {
|
ADAPTER_NAME_2_TYPE = {
|
||||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||||
"vchat": PlatformAdapterType.VCHAT,
|
|
||||||
"gewechat": PlatformAdapterType.GEWECHAT,
|
|
||||||
"telegram": PlatformAdapterType.TELEGRAM,
|
"telegram": PlatformAdapterType.TELEGRAM,
|
||||||
"wecom": PlatformAdapterType.WECOM,
|
"wecom": PlatformAdapterType.WECOM,
|
||||||
"lark": PlatformAdapterType.LARK,
|
"lark": PlatformAdapterType.LARK,
|
||||||
|
"dingtalk": PlatformAdapterType.DINGTALK,
|
||||||
|
"discord": PlatformAdapterType.DISCORD,
|
||||||
|
"slack": PlatformAdapterType.SLACK,
|
||||||
|
"kook": PlatformAdapterType.KOOK,
|
||||||
|
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
||||||
|
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||||
|
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
from ..star import star_registry, StarMetadata, star_map
|
import warnings
|
||||||
|
|
||||||
|
from astrbot.core.star import StarMetadata, star_map
|
||||||
|
|
||||||
|
_warned_register_star = False
|
||||||
|
|
||||||
|
|
||||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
||||||
"""注册一个插件(Star)。
|
"""注册一个插件(Star)。
|
||||||
|
|
||||||
|
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||||
|
在 v3.5.19 版本之后(不含),您不需要使用该装饰器来装饰插件类,
|
||||||
|
AstrBot 会自动识别继承自 Star 的类并将其作为插件类加载。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 插件名称。
|
name: 插件名称。
|
||||||
author: 作者。
|
author: 作者。
|
||||||
@@ -21,18 +29,32 @@ def register_star(name: str, author: str, desc: str, version: str, repo: str = N
|
|||||||
帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。`
|
帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(cls):
|
global _warned_register_star
|
||||||
star_metadata = StarMetadata(
|
if not _warned_register_star:
|
||||||
name=name,
|
_warned_register_star = True
|
||||||
author=author,
|
warnings.warn(
|
||||||
desc=desc,
|
"The 'register_star' decorator is deprecated and will be removed in a future version.",
|
||||||
version=version,
|
DeprecationWarning,
|
||||||
repo=repo,
|
stacklevel=2,
|
||||||
star_cls_type=cls,
|
|
||||||
module_path=cls.__module__,
|
|
||||||
)
|
)
|
||||||
star_registry.append(star_metadata)
|
|
||||||
star_map[cls.__module__] = star_metadata
|
def decorator(cls):
|
||||||
|
if not star_map.get(cls.__module__):
|
||||||
|
metadata = StarMetadata(
|
||||||
|
name=name,
|
||||||
|
author=author,
|
||||||
|
desc=desc,
|
||||||
|
version=version,
|
||||||
|
repo=repo,
|
||||||
|
)
|
||||||
|
star_map[cls.__module__] = metadata
|
||||||
|
else:
|
||||||
|
star_map[cls.__module__].name = name
|
||||||
|
star_map[cls.__module__].author = author
|
||||||
|
star_map[cls.__module__].desc = desc
|
||||||
|
star_map[cls.__module__].version = version
|
||||||
|
star_map[cls.__module__].repo = repo
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|||||||
293
astrbot/core/star/session_llm_manager.py
Normal file
293
astrbot/core/star/session_llm_manager.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from astrbot.core import logger, sp
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
|
class SessionServiceManager:
|
||||||
|
"""管理会话级别的服务启停状态,包括LLM和TTS"""
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# LLM 相关方法
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||||
|
"""检查LLM是否在指定会话中启用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取会话服务配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
session_services = session_config.get(session_id, {})
|
||||||
|
|
||||||
|
# 如果配置了该会话的LLM状态,返回该状态
|
||||||
|
llm_enabled = session_services.get("llm_enabled")
|
||||||
|
if llm_enabled is not None:
|
||||||
|
return llm_enabled
|
||||||
|
|
||||||
|
# 如果没有配置,默认为启用(兼容性考虑)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||||
|
"""设置LLM在指定会话中的启停状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
enabled: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取当前配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
if session_id not in session_config:
|
||||||
|
session_config[session_id] = {}
|
||||||
|
|
||||||
|
# 设置LLM状态
|
||||||
|
session_config[session_id]["llm_enabled"] = enabled
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
sp.put("session_service_config", session_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||||
|
"""检查是否应该处理LLM请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: 消息事件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示应该处理,False表示跳过
|
||||||
|
"""
|
||||||
|
session_id = event.unified_msg_origin
|
||||||
|
return SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TTS 相关方法
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||||
|
"""检查TTS是否在指定会话中启用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取会话服务配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
session_services = session_config.get(session_id, {})
|
||||||
|
|
||||||
|
# 如果配置了该会话的TTS状态,返回该状态
|
||||||
|
tts_enabled = session_services.get("tts_enabled")
|
||||||
|
if tts_enabled is not None:
|
||||||
|
return tts_enabled
|
||||||
|
|
||||||
|
# 如果没有配置,默认为启用(兼容性考虑)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||||
|
"""设置TTS在指定会话中的启停状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
enabled: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取当前配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
if session_id not in session_config:
|
||||||
|
session_config[session_id] = {}
|
||||||
|
|
||||||
|
# 设置TTS状态
|
||||||
|
session_config[session_id]["tts_enabled"] = enabled
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
sp.put("session_service_config", session_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||||
|
"""检查是否应该处理TTS请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: 消息事件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示应该处理,False表示跳过
|
||||||
|
"""
|
||||||
|
session_id = event.unified_msg_origin
|
||||||
|
return SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 会话整体启停相关方法
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_session_enabled(session_id: str) -> bool:
|
||||||
|
"""检查会话是否整体启用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取会话服务配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
session_services = session_config.get(session_id, {})
|
||||||
|
|
||||||
|
# 如果配置了该会话的整体状态,返回该状态
|
||||||
|
session_enabled = session_services.get("session_enabled")
|
||||||
|
if session_enabled is not None:
|
||||||
|
return session_enabled
|
||||||
|
|
||||||
|
# 如果没有配置,默认为启用(兼容性考虑)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_session_status(session_id: str, enabled: bool) -> None:
|
||||||
|
"""设置会话的整体启停状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
enabled: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取当前配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
if session_id not in session_config:
|
||||||
|
session_config[session_id] = {}
|
||||||
|
|
||||||
|
# 设置会话整体状态
|
||||||
|
session_config[session_id]["session_enabled"] = enabled
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
sp.put("session_service_config", session_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
||||||
|
"""检查是否应该处理会话请求(会话整体启停检查)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: 消息事件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示应该处理,False表示跳过
|
||||||
|
"""
|
||||||
|
session_id = event.unified_msg_origin
|
||||||
|
return SessionServiceManager.is_session_enabled(session_id)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 会话命名相关方法
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_session_custom_name(session_id: str) -> str:
|
||||||
|
"""获取会话的自定义名称
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 自定义名称,如果没有设置则返回None
|
||||||
|
"""
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
session_services = session_config.get(session_id, {})
|
||||||
|
return session_services.get("custom_name")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
||||||
|
"""设置会话的自定义名称
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||||
|
"""
|
||||||
|
# 获取当前配置
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
if session_id not in session_config:
|
||||||
|
session_config[session_id] = {}
|
||||||
|
|
||||||
|
# 设置自定义名称
|
||||||
|
if custom_name and custom_name.strip():
|
||||||
|
session_config[session_id]["custom_name"] = custom_name.strip()
|
||||||
|
else:
|
||||||
|
# 如果传入空名称,则删除自定义名称
|
||||||
|
session_config[session_id].pop("custom_name", None)
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
sp.put("session_service_config", session_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_session_display_name(session_id: str) -> str:
|
||||||
|
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 显示名称
|
||||||
|
"""
|
||||||
|
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||||
|
if custom_name:
|
||||||
|
return custom_name
|
||||||
|
|
||||||
|
# 如果没有自定义名称,返回session_id的最后一段
|
||||||
|
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 通用配置方法
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_session_service_config(session_id: str) -> Dict[str, bool]:
|
||||||
|
"""获取指定会话的服务配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
|
||||||
|
"""
|
||||||
|
session_config = sp.get("session_service_config", {}) or {}
|
||||||
|
return session_config.get(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"session_enabled": True, # 默认启用
|
||||||
|
"llm_enabled": True, # 默认启用
|
||||||
|
"tts_enabled": True, # 默认启用
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
|
||||||
|
"""获取所有会话的服务配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Dict[str, bool]]: 所有会话的服务配置
|
||||||
|
"""
|
||||||
|
return sp.get("session_service_config", {}) or {}
|
||||||
142
astrbot/core/star/session_plugin_manager.py
Normal file
142
astrbot/core/star/session_plugin_manager.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""
|
||||||
|
会话插件管理器 - 负责管理每个会话的插件启停状态
|
||||||
|
"""
|
||||||
|
|
||||||
|
from astrbot.core import sp, logger
|
||||||
|
from typing import Dict, List
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
|
class SessionPluginManager:
|
||||||
|
"""管理会话级别的插件启停状态"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||||
|
"""检查插件是否在指定会话中启用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
plugin_name: 插件名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取会话插件配置
|
||||||
|
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||||
|
session_config = session_plugin_config.get(session_id, {})
|
||||||
|
|
||||||
|
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||||
|
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||||
|
|
||||||
|
# 如果插件在禁用列表中,返回False
|
||||||
|
if plugin_name in disabled_plugins:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 如果插件在启用列表中,返回True
|
||||||
|
if plugin_name in enabled_plugins:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_plugin_status_for_session(
|
||||||
|
session_id: str, plugin_name: str, enabled: bool
|
||||||
|
) -> None:
|
||||||
|
"""设置插件在指定会话中的启停状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
plugin_name: 插件名称
|
||||||
|
enabled: True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
# 获取当前配置
|
||||||
|
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||||
|
if session_id not in session_plugin_config:
|
||||||
|
session_plugin_config[session_id] = {
|
||||||
|
"enabled_plugins": [],
|
||||||
|
"disabled_plugins": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
session_config = session_plugin_config[session_id]
|
||||||
|
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||||
|
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||||
|
|
||||||
|
if enabled:
|
||||||
|
# 启用插件
|
||||||
|
if plugin_name in disabled_plugins:
|
||||||
|
disabled_plugins.remove(plugin_name)
|
||||||
|
if plugin_name not in enabled_plugins:
|
||||||
|
enabled_plugins.append(plugin_name)
|
||||||
|
else:
|
||||||
|
# 禁用插件
|
||||||
|
if plugin_name in enabled_plugins:
|
||||||
|
enabled_plugins.remove(plugin_name)
|
||||||
|
if plugin_name not in disabled_plugins:
|
||||||
|
disabled_plugins.append(plugin_name)
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
session_config["enabled_plugins"] = enabled_plugins
|
||||||
|
session_config["disabled_plugins"] = disabled_plugins
|
||||||
|
session_plugin_config[session_id] = session_config
|
||||||
|
sp.put("session_plugin_config", session_plugin_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]:
|
||||||
|
"""获取指定会话的插件配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID (unified_msg_origin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||||
|
"""
|
||||||
|
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||||
|
return session_plugin_config.get(
|
||||||
|
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List:
|
||||||
|
"""根据会话配置过滤处理器列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: 消息事件
|
||||||
|
handlers: 原始处理器列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: 过滤后的处理器列表
|
||||||
|
"""
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
|
||||||
|
session_id = event.unified_msg_origin
|
||||||
|
filtered_handlers = []
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
# 获取处理器对应的插件
|
||||||
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
|
if not plugin:
|
||||||
|
# 如果找不到插件元数据,允许执行(可能是系统插件)
|
||||||
|
filtered_handlers.append(handler)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 跳过保留插件(系统插件)
|
||||||
|
if plugin.reserved:
|
||||||
|
filtered_handlers.append(handler)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查插件是否在当前会话中启用
|
||||||
|
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||||
|
session_id, plugin.name
|
||||||
|
):
|
||||||
|
filtered_handlers.append(handler)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return filtered_handlers
|
||||||
@@ -1,14 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from types import ModuleType
|
|
||||||
from typing import List, Dict
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
|
|
||||||
star_registry: List[StarMetadata] = []
|
star_registry: list[StarMetadata] = []
|
||||||
star_map: Dict[str, StarMetadata] = {}
|
star_map: dict[str, StarMetadata] = {}
|
||||||
"""key 是模块路径,__module__"""
|
"""key 是模块路径,__module__"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import Star
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StarMetadata:
|
class StarMetadata:
|
||||||
@@ -18,22 +22,27 @@ class StarMetadata:
|
|||||||
当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。
|
当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str | None = None
|
||||||
author: str # 插件作者
|
"""插件名"""
|
||||||
desc: str # 插件简介
|
author: str | None = None
|
||||||
version: str # 插件版本
|
"""插件作者"""
|
||||||
repo: str = None # 插件仓库地址
|
desc: str | None = None
|
||||||
|
"""插件简介"""
|
||||||
|
version: str | None = None
|
||||||
|
"""插件版本"""
|
||||||
|
repo: str | None = None
|
||||||
|
"""插件仓库地址"""
|
||||||
|
|
||||||
star_cls_type: type = None
|
star_cls_type: type[Star] | None = None
|
||||||
"""插件的类对象的类型"""
|
"""插件的类对象的类型"""
|
||||||
module_path: str = None
|
module_path: str | None = None
|
||||||
"""插件的模块路径"""
|
"""插件的模块路径"""
|
||||||
|
|
||||||
star_cls: object = None
|
star_cls: Star | None = None
|
||||||
"""插件的类对象"""
|
"""插件的类对象"""
|
||||||
module: ModuleType = None
|
module: ModuleType | None = None
|
||||||
"""插件的模块对象"""
|
"""插件的模块对象"""
|
||||||
root_dir_name: str = None
|
root_dir_name: str | None = None
|
||||||
"""插件的目录名称"""
|
"""插件的目录名称"""
|
||||||
reserved: bool = False
|
reserved: bool = False
|
||||||
"""是否是 AstrBot 的保留插件"""
|
"""是否是 AstrBot 的保留插件"""
|
||||||
@@ -41,17 +50,20 @@ class StarMetadata:
|
|||||||
activated: bool = True
|
activated: bool = True
|
||||||
"""是否被激活"""
|
"""是否被激活"""
|
||||||
|
|
||||||
config: AstrBotConfig = None
|
config: AstrBotConfig | None = None
|
||||||
"""插件配置"""
|
"""插件配置"""
|
||||||
|
|
||||||
star_handler_full_names: List[str] = field(default_factory=list)
|
star_handler_full_names: list[str] = field(default_factory=list)
|
||||||
"""注册的 Handler 的全名列表"""
|
"""注册的 Handler 的全名列表"""
|
||||||
|
|
||||||
supported_platforms: Dict[str, bool] = field(default_factory=dict)
|
supported_platforms: dict[str, bool] = field(default_factory=dict)
|
||||||
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||||
|
|
||||||
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||||
"""更新插件支持的平台列表
|
"""更新插件支持的平台列表
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .star import star_map
|
|||||||
|
|
||||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||||
|
|
||||||
|
|
||||||
class StarHandlerRegistry(Generic[T]):
|
class StarHandlerRegistry(Generic[T]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||||
@@ -49,7 +50,8 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
self, module_name: str
|
self, module_name: str
|
||||||
) -> List[StarHandlerMetadata]:
|
) -> List[StarHandlerMetadata]:
|
||||||
return [
|
return [
|
||||||
handler for handler in self._handlers
|
handler
|
||||||
|
for handler in self._handlers
|
||||||
if handler.handler_module_path == module_name
|
if handler.handler_module_path == module_name
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -67,6 +69,7 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._handlers)
|
return len(self._handlers)
|
||||||
|
|
||||||
|
|
||||||
star_handlers_registry = StarHandlerRegistry()
|
star_handlers_registry = StarHandlerRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@@ -37,12 +36,6 @@ except ImportError:
|
|||||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||||
|
|
||||||
try:
|
|
||||||
import nh3
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
|
||||||
nh3 = None
|
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
def __init__(self, context: Context, config: AstrBotConfig):
|
def __init__(self, context: Context, config: AstrBotConfig):
|
||||||
@@ -64,6 +57,8 @@ class PluginManager:
|
|||||||
"""保留插件的路径。在 packages 目录下"""
|
"""保留插件的路径。在 packages 目录下"""
|
||||||
self.conf_schema_fname = "_conf_schema.json"
|
self.conf_schema_fname = "_conf_schema.json"
|
||||||
"""插件配置 Schema 文件名"""
|
"""插件配置 Schema 文件名"""
|
||||||
|
self._pm_lock = asyncio.Lock()
|
||||||
|
"""StarManager操作互斥锁"""
|
||||||
|
|
||||||
self.failed_plugin_info = ""
|
self.failed_plugin_info = ""
|
||||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||||
@@ -119,7 +114,8 @@ class PluginManager:
|
|||||||
reloaded_plugins.add(plugin_name)
|
reloaded_plugins.add(plugin_name)
|
||||||
break
|
break
|
||||||
|
|
||||||
def _get_classes(self, arg: ModuleType):
|
@staticmethod
|
||||||
|
def _get_classes(arg: ModuleType):
|
||||||
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
||||||
classes = []
|
classes = []
|
||||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||||
@@ -129,7 +125,8 @@ class PluginManager:
|
|||||||
break
|
break
|
||||||
return classes
|
return classes
|
||||||
|
|
||||||
def _get_modules(self, path):
|
@staticmethod
|
||||||
|
def _get_modules(path):
|
||||||
modules = []
|
modules = []
|
||||||
|
|
||||||
dirs = os.listdir(path)
|
dirs = os.listdir(path)
|
||||||
@@ -155,7 +152,7 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
return modules
|
return modules
|
||||||
|
|
||||||
def _get_plugin_modules(self) -> List[dict]:
|
def _get_plugin_modules(self) -> list[dict]:
|
||||||
plugins = []
|
plugins = []
|
||||||
if os.path.exists(self.plugin_store_path):
|
if os.path.exists(self.plugin_store_path):
|
||||||
plugins.extend(self._get_modules(self.plugin_store_path))
|
plugins.extend(self._get_modules(self.plugin_store_path))
|
||||||
@@ -166,7 +163,7 @@ class PluginManager:
|
|||||||
plugins.extend(_p)
|
plugins.extend(_p)
|
||||||
return plugins
|
return plugins
|
||||||
|
|
||||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
async def _check_plugin_dept_update(self, target_plugin: str | None = None):
|
||||||
"""检查插件的依赖
|
"""检查插件的依赖
|
||||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||||
"""
|
"""
|
||||||
@@ -189,10 +186,11 @@ class PluginManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||||
|
|
||||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj=None) -> StarMetadata:
|
@staticmethod
|
||||||
"""v3.4.0 以前的方式载入插件元数据
|
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
|
||||||
|
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||||
|
|
||||||
先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
|
||||||
"""
|
"""
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|
||||||
@@ -204,11 +202,14 @@ class PluginManager:
|
|||||||
os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8"
|
os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8"
|
||||||
) as f:
|
) as f:
|
||||||
metadata = yaml.safe_load(f)
|
metadata = yaml.safe_load(f)
|
||||||
elif plugin_obj:
|
elif plugin_obj and hasattr(plugin_obj, "info"):
|
||||||
# 使用 info() 函数
|
# 使用 info() 函数
|
||||||
metadata = plugin_obj.info()
|
metadata = plugin_obj.info()
|
||||||
|
|
||||||
if isinstance(metadata, dict):
|
if isinstance(metadata, dict):
|
||||||
|
if "desc" not in metadata and "description" in metadata:
|
||||||
|
metadata["desc"] = metadata["description"]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
"name" not in metadata
|
"name" not in metadata
|
||||||
or "desc" not in metadata
|
or "desc" not in metadata
|
||||||
@@ -228,8 +229,9 @@ class PluginManager:
|
|||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _get_plugin_related_modules(
|
def _get_plugin_related_modules(
|
||||||
self, plugin_root_dir: str, is_reserved: bool = False
|
plugin_root_dir: str, is_reserved: bool = False
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""获取与指定插件相关的所有已加载模块名
|
"""获取与指定插件相关的所有已加载模块名
|
||||||
|
|
||||||
@@ -251,8 +253,8 @@ class PluginManager:
|
|||||||
|
|
||||||
def _purge_modules(
|
def _purge_modules(
|
||||||
self,
|
self,
|
||||||
module_patterns: list[str] = None,
|
module_patterns: list[str] | None = None,
|
||||||
root_dir_name: str = None,
|
root_dir_name: str | None = None,
|
||||||
is_reserved: bool = False,
|
is_reserved: bool = False,
|
||||||
):
|
):
|
||||||
"""从 sys.modules 中移除指定的模块
|
"""从 sys.modules 中移除指定的模块
|
||||||
@@ -293,50 +295,51 @@ class PluginManager:
|
|||||||
- success (bool): 重载是否成功
|
- success (bool): 重载是否成功
|
||||||
- error_message (str|None): 错误信息,成功时为 None
|
- error_message (str|None): 错误信息,成功时为 None
|
||||||
"""
|
"""
|
||||||
specified_module_path = None
|
async with self._pm_lock:
|
||||||
if specified_plugin_name:
|
specified_module_path = None
|
||||||
for smd in star_registry:
|
if specified_plugin_name:
|
||||||
if smd.name == specified_plugin_name:
|
for smd in star_registry:
|
||||||
specified_module_path = smd.module_path
|
if smd.name == specified_plugin_name:
|
||||||
break
|
specified_module_path = smd.module_path
|
||||||
|
break
|
||||||
|
|
||||||
# 终止插件
|
# 终止插件
|
||||||
if not specified_module_path:
|
if not specified_module_path:
|
||||||
# 重载所有插件
|
# 重载所有插件
|
||||||
for smd in star_registry:
|
for smd in star_registry:
|
||||||
try:
|
try:
|
||||||
await self._terminate_plugin(smd)
|
await self._terminate_plugin(smd)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(traceback.format_exc())
|
logger.warning(traceback.format_exc())
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||||
)
|
)
|
||||||
|
if smd.name and smd.module_path:
|
||||||
|
await self._unbind_plugin(smd.name, smd.module_path)
|
||||||
|
|
||||||
await self._unbind_plugin(smd.name, smd.module_path)
|
star_handlers_registry.clear()
|
||||||
|
star_map.clear()
|
||||||
|
star_registry.clear()
|
||||||
|
else:
|
||||||
|
# 只重载指定插件
|
||||||
|
smd = star_map.get(specified_module_path)
|
||||||
|
if smd:
|
||||||
|
try:
|
||||||
|
await self._terminate_plugin(smd)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(traceback.format_exc())
|
||||||
|
logger.warning(
|
||||||
|
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||||
|
)
|
||||||
|
if smd.name:
|
||||||
|
await self._unbind_plugin(smd.name, specified_module_path)
|
||||||
|
|
||||||
star_handlers_registry.clear()
|
result = await self.load(specified_module_path)
|
||||||
star_map.clear()
|
|
||||||
star_registry.clear()
|
|
||||||
else:
|
|
||||||
# 只重载指定插件
|
|
||||||
smd = star_map.get(specified_module_path)
|
|
||||||
if smd:
|
|
||||||
try:
|
|
||||||
await self._terminate_plugin(smd)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(traceback.format_exc())
|
|
||||||
logger.warning(
|
|
||||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._unbind_plugin(smd.name, specified_module_path)
|
# 更新所有插件的平台兼容性
|
||||||
|
await self.update_all_platform_compatibility()
|
||||||
|
|
||||||
result = await self.load(specified_module_path)
|
return result
|
||||||
|
|
||||||
# 更新所有插件的平台兼容性
|
|
||||||
await self.update_all_platform_compatibility()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def update_all_platform_compatibility(self):
|
async def update_all_platform_compatibility(self):
|
||||||
"""更新所有插件的平台兼容性设置"""
|
"""更新所有插件的平台兼容性设置"""
|
||||||
@@ -435,7 +438,7 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if path in star_map:
|
if path in star_map:
|
||||||
# 通过装饰器的方式注册插件
|
# 通过 __init__subclass__ 注册插件
|
||||||
metadata = star_map[path]
|
metadata = star_map[path]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -449,13 +452,15 @@ class PluginManager:
|
|||||||
metadata.desc = metadata_yaml.desc
|
metadata.desc = metadata_yaml.desc
|
||||||
metadata.version = metadata_yaml.version
|
metadata.version = metadata_yaml.version
|
||||||
metadata.repo = metadata_yaml.repo
|
metadata.repo = metadata_yaml.repo
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(
|
||||||
|
f"插件 {root_dir_name} 元数据载入失败: {str(e)}。使用默认元数据。"
|
||||||
|
)
|
||||||
|
logger.info(metadata)
|
||||||
metadata.config = plugin_config
|
metadata.config = plugin_config
|
||||||
if path not in inactivated_plugins:
|
if path not in inactivated_plugins:
|
||||||
# 只有没有禁用插件时才实例化插件类
|
# 只有没有禁用插件时才实例化插件类
|
||||||
if plugin_config:
|
if plugin_config and metadata.star_cls_type:
|
||||||
# metadata.config = plugin_config
|
|
||||||
try:
|
try:
|
||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
context=self.context, config=plugin_config
|
context=self.context, config=plugin_config
|
||||||
@@ -464,7 +469,7 @@ class PluginManager:
|
|||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
context=self.context
|
context=self.context
|
||||||
)
|
)
|
||||||
else:
|
elif metadata.star_cls_type:
|
||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
context=self.context
|
context=self.context
|
||||||
)
|
)
|
||||||
@@ -481,6 +486,10 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
metadata.update_platform_compatibility(plugin_enable_config)
|
metadata.update_platform_compatibility(plugin_enable_config)
|
||||||
|
|
||||||
|
assert metadata.module_path is not None, (
|
||||||
|
f"插件 {metadata.name} 的模块路径为空。"
|
||||||
|
)
|
||||||
|
|
||||||
# 绑定 handler
|
# 绑定 handler
|
||||||
related_handlers = (
|
related_handlers = (
|
||||||
star_handlers_registry.get_handlers_by_module_name(
|
star_handlers_registry.get_handlers_by_module_name(
|
||||||
@@ -489,7 +498,8 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
for handler in related_handlers:
|
for handler in related_handlers:
|
||||||
handler.handler = functools.partial(
|
handler.handler = functools.partial(
|
||||||
handler.handler, metadata.star_cls
|
handler.handler,
|
||||||
|
metadata.star_cls, # type: ignore
|
||||||
)
|
)
|
||||||
# 绑定 llm_tool handler
|
# 绑定 llm_tool handler
|
||||||
for func_tool in llm_tools.func_list:
|
for func_tool in llm_tools.func_list:
|
||||||
@@ -499,7 +509,8 @@ class PluginManager:
|
|||||||
):
|
):
|
||||||
func_tool.handler_module_path = metadata.module_path
|
func_tool.handler_module_path = metadata.module_path
|
||||||
func_tool.handler = functools.partial(
|
func_tool.handler = functools.partial(
|
||||||
func_tool.handler, metadata.star_cls
|
func_tool.handler,
|
||||||
|
metadata.star_cls, # type: ignore
|
||||||
)
|
)
|
||||||
if func_tool.name in inactivated_llm_tools:
|
if func_tool.name in inactivated_llm_tools:
|
||||||
func_tool.active = False
|
func_tool.active = False
|
||||||
@@ -526,13 +537,12 @@ class PluginManager:
|
|||||||
obj = getattr(module, classes[0])(
|
obj = getattr(module, classes[0])(
|
||||||
context=self.context
|
context=self.context
|
||||||
) # 实例化插件类
|
) # 实例化插件类
|
||||||
else:
|
|
||||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
|
||||||
|
|
||||||
metadata = None
|
|
||||||
metadata = self._load_plugin_metadata(
|
metadata = self._load_plugin_metadata(
|
||||||
plugin_path=plugin_dir_path, plugin_obj=obj
|
plugin_path=plugin_dir_path, plugin_obj=obj
|
||||||
)
|
)
|
||||||
|
if not metadata:
|
||||||
|
raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
|
||||||
metadata.star_cls = obj
|
metadata.star_cls = obj
|
||||||
metadata.config = plugin_config
|
metadata.config = plugin_config
|
||||||
metadata.module = module
|
metadata.module = module
|
||||||
@@ -547,6 +557,10 @@ class PluginManager:
|
|||||||
if metadata.module_path in inactivated_plugins:
|
if metadata.module_path in inactivated_plugins:
|
||||||
metadata.activated = False
|
metadata.activated = False
|
||||||
|
|
||||||
|
assert metadata.module_path is not None, (
|
||||||
|
f"插件 {metadata.name} 的模块路径为空。"
|
||||||
|
)
|
||||||
|
|
||||||
full_names = []
|
full_names = []
|
||||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||||
metadata.module_path
|
metadata.module_path
|
||||||
@@ -586,7 +600,7 @@ class PluginManager:
|
|||||||
metadata.star_handler_full_names = full_names
|
metadata.star_handler_full_names = full_names
|
||||||
|
|
||||||
# 执行 initialize() 方法
|
# 执行 initialize() 方法
|
||||||
if hasattr(metadata.star_cls, "initialize"):
|
if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
|
||||||
await metadata.star_cls.initialize()
|
await metadata.star_cls.initialize()
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
@@ -622,43 +636,45 @@ class PluginManager:
|
|||||||
- readme: README.md 文件的内容(如果存在)
|
- readme: README.md 文件的内容(如果存在)
|
||||||
如果找不到插件元数据则返回 None。
|
如果找不到插件元数据则返回 None。
|
||||||
"""
|
"""
|
||||||
plugin_path = await self.updator.install(repo_url, proxy)
|
async with self._pm_lock:
|
||||||
# reload the plugin
|
plugin_path = await self.updator.install(repo_url, proxy)
|
||||||
dir_name = os.path.basename(plugin_path)
|
# reload the plugin
|
||||||
await self.load(specified_dir_name=dir_name)
|
dir_name = os.path.basename(plugin_path)
|
||||||
|
await self.load(specified_dir_name=dir_name)
|
||||||
|
|
||||||
# Get the plugin metadata to return repo info
|
# Get the plugin metadata to return repo info
|
||||||
plugin = self.context.get_registered_star(dir_name)
|
plugin = self.context.get_registered_star(dir_name)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
# Try to find by other name if directory name doesn't match plugin name
|
# Try to find by other name if directory name doesn't match plugin name
|
||||||
for star in self.context.get_all_stars():
|
for star in self.context.get_all_stars():
|
||||||
if star.root_dir_name == dir_name:
|
if star.root_dir_name == dir_name:
|
||||||
plugin = star
|
plugin = star
|
||||||
break
|
break
|
||||||
|
|
||||||
# Extract README.md content if exists
|
# Extract README.md content if exists
|
||||||
readme_content = None
|
readme_content = None
|
||||||
readme_path = os.path.join(plugin_path, "README.md")
|
readme_path = os.path.join(plugin_path, "README.md")
|
||||||
if not os.path.exists(readme_path):
|
if not os.path.exists(readme_path):
|
||||||
readme_path = os.path.join(plugin_path, "readme.md")
|
readme_path = os.path.join(plugin_path, "readme.md")
|
||||||
|
|
||||||
if os.path.exists(readme_path) and nh3:
|
if os.path.exists(readme_path):
|
||||||
try:
|
try:
|
||||||
with open(readme_path, "r", encoding="utf-8") as f:
|
with open(readme_path, "r", encoding="utf-8") as f:
|
||||||
readme_content = f.read()
|
readme_content = f.read()
|
||||||
cleaned_content = nh3.clean(readme_content)
|
except Exception as e:
|
||||||
except Exception as e:
|
logger.warning(
|
||||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
plugin_info = None
|
plugin_info = None
|
||||||
if plugin:
|
if plugin:
|
||||||
plugin_info = {
|
plugin_info = {
|
||||||
"repo": plugin.repo,
|
"repo": plugin.repo,
|
||||||
"readme": cleaned_content,
|
"readme": readme_content,
|
||||||
"name": plugin.name,
|
"name": plugin.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
return plugin_info
|
return plugin_info
|
||||||
|
|
||||||
async def uninstall_plugin(self, plugin_name: str):
|
async def uninstall_plugin(self, plugin_name: str):
|
||||||
"""卸载指定的插件。
|
"""卸载指定的插件。
|
||||||
@@ -669,32 +685,33 @@ class PluginManager:
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||||
"""
|
"""
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
async with self._pm_lock:
|
||||||
if not plugin:
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
raise Exception("插件不存在。")
|
if not plugin:
|
||||||
if plugin.reserved:
|
raise Exception("插件不存在。")
|
||||||
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
if plugin.reserved:
|
||||||
root_dir_name = plugin.root_dir_name
|
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
||||||
ppath = self.plugin_store_path
|
root_dir_name = plugin.root_dir_name
|
||||||
|
ppath = self.plugin_store_path
|
||||||
|
|
||||||
# 终止插件
|
# 终止插件
|
||||||
try:
|
try:
|
||||||
await self._terminate_plugin(plugin)
|
await self._terminate_plugin(plugin)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(traceback.format_exc())
|
logger.warning(traceback.format_exc())
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
|
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 从 star_registry 和 star_map 中删除
|
# 从 star_registry 和 star_map 中删除
|
||||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
remove_dir(os.path.join(ppath, root_dir_name))
|
remove_dir(os.path.join(ppath, root_dir_name))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||||
"""解绑并移除一个插件。
|
"""解绑并移除一个插件。
|
||||||
@@ -725,6 +742,9 @@ class PluginManager:
|
|||||||
]:
|
]:
|
||||||
del star_handlers_registry.star_handlers_map[k]
|
del star_handlers_registry.star_handlers_map[k]
|
||||||
|
|
||||||
|
if plugin is None:
|
||||||
|
return
|
||||||
|
|
||||||
self._purge_modules(
|
self._purge_modules(
|
||||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||||
)
|
)
|
||||||
@@ -747,35 +767,37 @@ class PluginManager:
|
|||||||
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
|
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
|
||||||
并且同时将插件启用的 llm_tool 禁用。
|
并且同时将插件启用的 llm_tool 禁用。
|
||||||
"""
|
"""
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
async with self._pm_lock:
|
||||||
if not plugin:
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
raise Exception("插件不存在。")
|
if not plugin:
|
||||||
|
raise Exception("插件不存在。")
|
||||||
|
|
||||||
# 调用插件的终止方法
|
# 调用插件的终止方法
|
||||||
await self._terminate_plugin(plugin)
|
await self._terminate_plugin(plugin)
|
||||||
|
|
||||||
# 加入到 shared_preferences 中
|
# 加入到 shared_preferences 中
|
||||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||||
if plugin.module_path not in inactivated_plugins:
|
if plugin.module_path not in inactivated_plugins:
|
||||||
inactivated_plugins.append(plugin.module_path)
|
inactivated_plugins.append(plugin.module_path)
|
||||||
|
|
||||||
inactivated_llm_tools: list = list(
|
inactivated_llm_tools: list = list(
|
||||||
set(sp.get("inactivated_llm_tools", []))
|
set(sp.get("inactivated_llm_tools", []))
|
||||||
) # 后向兼容
|
) # 后向兼容
|
||||||
|
|
||||||
# 禁用插件启用的 llm_tool
|
# 禁用插件启用的 llm_tool
|
||||||
for func_tool in llm_tools.func_list:
|
for func_tool in llm_tools.func_list:
|
||||||
if func_tool.handler_module_path == plugin.module_path:
|
if func_tool.handler_module_path == plugin.module_path:
|
||||||
func_tool.active = False
|
func_tool.active = False
|
||||||
if func_tool.name not in inactivated_llm_tools:
|
if func_tool.name not in inactivated_llm_tools:
|
||||||
inactivated_llm_tools.append(func_tool.name)
|
inactivated_llm_tools.append(func_tool.name)
|
||||||
|
|
||||||
sp.put("inactivated_plugins", inactivated_plugins)
|
sp.put("inactivated_plugins", inactivated_plugins)
|
||||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||||
|
|
||||||
plugin.activated = False
|
plugin.activated = False
|
||||||
|
|
||||||
async def _terminate_plugin(self, star_metadata: StarMetadata):
|
@staticmethod
|
||||||
|
async def _terminate_plugin(star_metadata: StarMetadata):
|
||||||
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
||||||
logger.info(f"正在终止插件 {star_metadata.name} ...")
|
logger.info(f"正在终止插件 {star_metadata.name} ...")
|
||||||
|
|
||||||
@@ -784,11 +806,14 @@ class PluginManager:
|
|||||||
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
||||||
return
|
return
|
||||||
|
|
||||||
if hasattr(star_metadata.star_cls, "__del__"):
|
if star_metadata.star_cls is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||||
asyncio.get_event_loop().run_in_executor(
|
asyncio.get_event_loop().run_in_executor(
|
||||||
None, star_metadata.star_cls.__del__
|
None, star_metadata.star_cls.__del__
|
||||||
)
|
)
|
||||||
elif hasattr(star_metadata.star_cls, "terminate"):
|
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
async def turn_on_plugin(self, plugin_name: str):
|
async def turn_on_plugin(self, plugin_name: str):
|
||||||
|
|||||||
@@ -182,7 +182,9 @@ class StarTools:
|
|||||||
|
|
||||||
plugin_name = metadata.name
|
plugin_name = metadata.name
|
||||||
|
|
||||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
data_dir = Path(
|
||||||
|
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data_dir.mkdir(parents=True, exist_ok=True)
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
@@ -56,9 +56,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
|||||||
try:
|
try:
|
||||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
args = [
|
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
|
||||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
|
|||||||
raise exc_info[1]
|
raise exc_info[1]
|
||||||
|
|
||||||
|
|
||||||
def remove_dir(file_path) -> bool:
|
def remove_dir(file_path: str) -> bool:
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
return True
|
return True
|
||||||
shutil.rmtree(file_path, onerror=on_error)
|
shutil.rmtree(file_path, onerror=on_error)
|
||||||
|
|||||||
29
astrbot/core/utils/session_lock.py
Normal file
29
astrbot/core/utils/session_lock.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
|
class SessionLockManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
|
self._lock_count: dict[str, int] = defaultdict(int)
|
||||||
|
self._access_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def acquire_lock(self, session_id: str):
|
||||||
|
async with self._access_lock:
|
||||||
|
lock = self._locks[session_id]
|
||||||
|
self._lock_count[session_id] += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with lock:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
async with self._access_lock:
|
||||||
|
self._lock_count[session_id] -= 1
|
||||||
|
if self._lock_count[session_id] == 0:
|
||||||
|
self._locks.pop(session_id, None)
|
||||||
|
self._lock_count.pop(session_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
session_lock_manager = SessionLockManager()
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import TypeVar
|
||||||
from .astrbot_path import get_astrbot_data_path
|
from .astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
_VT = TypeVar("_VT")
|
||||||
|
|
||||||
|
|
||||||
class SharedPreferences:
|
class SharedPreferences:
|
||||||
def __init__(self, path=None):
|
def __init__(self, path=None):
|
||||||
@@ -24,7 +27,7 @@ class SharedPreferences:
|
|||||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default: _VT = None) -> _VT:
|
||||||
return self._data.get(key, default)
|
return self._data.get(key, default)
|
||||||
|
|
||||||
def put(self, key, value):
|
def put(self, key, value):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
|||||||
|
|
||||||
|
|
||||||
class NetworkRenderStrategy(RenderStrategy):
|
class NetworkRenderStrategy(RenderStrategy):
|
||||||
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
|
def __init__(self, base_url: str | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||||
@@ -34,18 +34,22 @@ class NetworkRenderStrategy(RenderStrategy):
|
|||||||
self.BASE_RENDER_URL += "/text2img"
|
self.BASE_RENDER_URL += "/text2img"
|
||||||
|
|
||||||
async def render_custom_template(
|
async def render_custom_template(
|
||||||
self, tmpl_str: str, tmpl_data: dict, return_url: bool = True
|
self,
|
||||||
|
tmpl_str: str,
|
||||||
|
tmpl_data: dict,
|
||||||
|
return_url: bool = True,
|
||||||
|
options: dict | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""使用自定义文转图模板"""
|
"""使用自定义文转图模板"""
|
||||||
|
default_options = {"full_page": True, "type": "jpeg", "quality": 40}
|
||||||
|
if options:
|
||||||
|
default_options |= options
|
||||||
|
|
||||||
post_data = {
|
post_data = {
|
||||||
"tmpl": tmpl_str,
|
"tmpl": tmpl_str,
|
||||||
"json": return_url,
|
"json": return_url,
|
||||||
"tmpldata": tmpl_data,
|
"tmpldata": tmpl_data,
|
||||||
"options": {
|
"options": default_options,
|
||||||
"full_page": True,
|
|
||||||
"type": "jpeg",
|
|
||||||
"quality": 40,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
if return_url:
|
if return_url:
|
||||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ logger = LogManager.GetLogger(log_name="astrbot")
|
|||||||
|
|
||||||
|
|
||||||
class HtmlRenderer:
|
class HtmlRenderer:
|
||||||
def __init__(self, endpoint_url: str = None):
|
def __init__(self, endpoint_url: str | None = None):
|
||||||
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
||||||
self.local_strategy = LocalRenderStrategy()
|
self.local_strategy = LocalRenderStrategy()
|
||||||
|
|
||||||
@@ -16,19 +16,24 @@ class HtmlRenderer:
|
|||||||
self.network_strategy.set_endpoint(endpoint_url)
|
self.network_strategy.set_endpoint(endpoint_url)
|
||||||
|
|
||||||
async def render_custom_template(
|
async def render_custom_template(
|
||||||
self, tmpl_str: str, tmpl_data: dict, return_url: bool = False
|
self,
|
||||||
|
tmpl_str: str,
|
||||||
|
tmpl_data: dict,
|
||||||
|
return_url: bool = False,
|
||||||
|
options: dict | None = None,
|
||||||
):
|
):
|
||||||
"""使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
|
"""使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
|
||||||
@param tmpl_str: HTML Jinja2 模板。
|
@param tmpl_str: HTML Jinja2 模板。
|
||||||
@param tmpl_data: jinja2 模板数据。
|
@param tmpl_data: jinja2 模板数据。
|
||||||
|
@param options: 渲染选项。
|
||||||
|
|
||||||
@return: 图片 URL 或者文件路径,取决于 return_url 参数。
|
@return: 图片 URL 或者文件路径,取决于 return_url 参数。
|
||||||
|
|
||||||
@example: 参见 https://astrbot.app 插件开发部分。
|
@example: 参见 https://astrbot.app 插件开发部分。
|
||||||
"""
|
"""
|
||||||
local = locals()
|
return await self.network_strategy.render_custom_template(
|
||||||
local.pop("self")
|
tmpl_str, tmpl_data, return_url, options
|
||||||
return await self.network_strategy.render_custom_template(**local)
|
)
|
||||||
|
|
||||||
async def render_t2i(
|
async def render_t2i(
|
||||||
self, text: str, use_network: bool = True, return_url: bool = False
|
self, text: str, use_network: bool = True, return_url: bool = False
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import base64
|
import base64
|
||||||
import wave
|
import wave
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from astrbot.core import logger
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
|
||||||
@@ -57,33 +59,89 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
|||||||
return duration
|
return duration
|
||||||
|
|
||||||
|
|
||||||
async def wav_to_tencent_silk_base64(wav_path: str) -> str:
|
async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
将 WAV 文件转为 Silk,并返回 Base64 字符串。
|
将 MP3 或其他音频格式转换为 PCM 16bit WAV,采样率24000Hz,单声道。
|
||||||
默认采样率为 24000,输出临时文件为 temp/output.silk。
|
若转换失败则抛出异常。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFmpeg
|
||||||
|
|
||||||
|
ff = FFmpeg()
|
||||||
|
ff.convert(input=input_path, output=output_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||||
|
|
||||||
|
p = await asyncio.create_subprocess_exec(
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-i",
|
||||||
|
input_path,
|
||||||
|
"-acodec",
|
||||||
|
"pcm_s16le",
|
||||||
|
"-ar",
|
||||||
|
"24000",
|
||||||
|
"-ac",
|
||||||
|
"1",
|
||||||
|
"-af",
|
||||||
|
"apad=pad_dur=2",
|
||||||
|
"-fflags",
|
||||||
|
"+genpts",
|
||||||
|
"-hide_banner",
|
||||||
|
output_path,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await p.communicate()
|
||||||
|
logger.info(f"[FFmpeg] stdout: {stdout.decode().strip()}")
|
||||||
|
logger.debug(f"[FFmpeg] stderr: {stderr.decode().strip()}")
|
||||||
|
logger.info(f"[FFmpeg] return code: {p.returncode}")
|
||||||
|
|
||||||
|
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
||||||
|
return output_path
|
||||||
|
else:
|
||||||
|
raise RuntimeError("生成的WAV文件不存在或为空")
|
||||||
|
|
||||||
|
|
||||||
|
async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
|
||||||
|
"""
|
||||||
|
将 MP3/WAV 文件转为 Tencent Silk 并返回 base64 编码与时长(秒)。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- wav_path: 输入 .wav 文件路径(需为 PCM 16bit)
|
- audio_path: 输入音频文件路径(.mp3 或 .wav)
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
- Base64 编码的 Silk 字符串
|
- silk_b64: Base64 编码的 Silk 字符串
|
||||||
- duration: 音频时长(秒)
|
- duration: 音频时长(秒)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import pilk
|
import pilk
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception("pysilk 模块未安装,请安装 pysilk") from e
|
raise Exception("未安装 pilk: pip install pilk") from e
|
||||||
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
os.makedirs(temp_dir, exist_ok=True)
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
|
||||||
with wave.open(wav_path, "rb") as wav:
|
# 是否需要转换为 WAV
|
||||||
rate = wav.getframerate()
|
ext = os.path.splitext(audio_path)[1].lower()
|
||||||
|
temp_wav = tempfile.NamedTemporaryFile(
|
||||||
|
suffix=".wav", delete=False, dir=temp_dir
|
||||||
|
).name
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(
|
if ext != ".wav":
|
||||||
|
await convert_to_pcm_wav(audio_path, temp_wav)
|
||||||
|
# 删除原文件
|
||||||
|
os.remove(audio_path)
|
||||||
|
wav_path = temp_wav
|
||||||
|
else:
|
||||||
|
wav_path = audio_path
|
||||||
|
|
||||||
|
with wave.open(wav_path, "rb") as wav_file:
|
||||||
|
rate = wav_file.getframerate()
|
||||||
|
|
||||||
|
silk_path = tempfile.NamedTemporaryFile(
|
||||||
suffix=".silk", delete=False, dir=temp_dir
|
suffix=".silk", delete=False, dir=temp_dir
|
||||||
) as tmp_file:
|
).name
|
||||||
silk_path = tmp_file.name
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
duration = await asyncio.to_thread(
|
duration = await asyncio.to_thread(
|
||||||
@@ -96,5 +154,7 @@ async def wav_to_tencent_silk_base64(wav_path: str) -> str:
|
|||||||
|
|
||||||
return silk_b64, duration # 已是秒
|
return silk_b64, duration # 已是秒
|
||||||
finally:
|
finally:
|
||||||
|
if os.path.exists(wav_path) and wav_path != audio_path:
|
||||||
|
os.remove(wav_path)
|
||||||
if os.path.exists(silk_path):
|
if os.path.exists(silk_path):
|
||||||
os.remove(silk_path)
|
os.remove(silk_path)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from .chat import ChatRoute
|
|||||||
from .tools import ToolsRoute # 导入新的ToolsRoute
|
from .tools import ToolsRoute # 导入新的ToolsRoute
|
||||||
from .conversation import ConversationRoute
|
from .conversation import ConversationRoute
|
||||||
from .file import FileRoute
|
from .file import FileRoute
|
||||||
|
from .session_management import SessionManagementRoute
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -23,4 +24,5 @@ __all__ = [
|
|||||||
"ToolsRoute",
|
"ToolsRoute",
|
||||||
"ConversationRoute",
|
"ConversationRoute",
|
||||||
"FileRoute",
|
"FileRoute",
|
||||||
|
"SessionManagementRoute",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import datetime
|
|||||||
import asyncio
|
import asyncio
|
||||||
from .route import Route, Response, RouteContext
|
from .route import Route, Response, RouteContext
|
||||||
from quart import request
|
from quart import request
|
||||||
from astrbot.core import WEBUI_SK, DEMO_MODE
|
from astrbot.core import DEMO_MODE
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -80,5 +80,8 @@ class AuthRoute(Route):
|
|||||||
"username": username,
|
"username": username,
|
||||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
||||||
}
|
}
|
||||||
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256")
|
jwt_token = self.config["dashboard"].get("jwt_secret", None)
|
||||||
|
if not jwt_token:
|
||||||
|
raise ValueError("JWT secret is not set in the cmd_config.")
|
||||||
|
token = jwt.encode(payload, jwt_token, algorithm="HS256")
|
||||||
return token
|
return token
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import uuid
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from .route import Route, Response, RouteContext
|
from .route import Route, Response, RouteContext
|
||||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||||
from quart import request, Response as QuartResponse, g, make_response
|
from quart import request, Response as QuartResponse, g, make_response
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -21,7 +21,6 @@ class ChatRoute(Route):
|
|||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.routes = {
|
self.routes = {
|
||||||
"/chat/send": ("POST", self.chat),
|
"/chat/send": ("POST", self.chat),
|
||||||
"/chat/listen": ("GET", self.listener),
|
|
||||||
"/chat/new_conversation": ("GET", self.new_conversation),
|
"/chat/new_conversation": ("GET", self.new_conversation),
|
||||||
"/chat/conversations": ("GET", self.get_conversations),
|
"/chat/conversations": ("GET", self.get_conversations),
|
||||||
"/chat/get_conversation": ("GET", self.get_conversation),
|
"/chat/get_conversation": ("GET", self.get_conversation),
|
||||||
@@ -40,9 +39,6 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
||||||
|
|
||||||
self.curr_user_cid = {}
|
|
||||||
self.curr_chat_sse = {}
|
|
||||||
|
|
||||||
async def status(self):
|
async def status(self):
|
||||||
has_llm_enabled = (
|
has_llm_enabled = (
|
||||||
self.core_lifecycle.provider_manager.curr_provider_inst is not None
|
self.core_lifecycle.provider_manager.curr_provider_inst is not None
|
||||||
@@ -124,6 +120,8 @@ class ChatRoute(Route):
|
|||||||
conversation_id = post_data["conversation_id"]
|
conversation_id = post_data["conversation_id"]
|
||||||
image_url = post_data.get("image_url")
|
image_url = post_data.get("image_url")
|
||||||
audio_url = post_data.get("audio_url")
|
audio_url = post_data.get("audio_url")
|
||||||
|
selected_provider = post_data.get("selected_provider")
|
||||||
|
selected_model = post_data.get("selected_model")
|
||||||
if not message and not image_url and not audio_url:
|
if not message and not image_url and not audio_url:
|
||||||
return (
|
return (
|
||||||
Response()
|
Response()
|
||||||
@@ -133,21 +131,10 @@ class ChatRoute(Route):
|
|||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
return Response().error("conversation_id is empty").__dict__
|
return Response().error("conversation_id is empty").__dict__
|
||||||
|
|
||||||
self.curr_user_cid[username] = conversation_id
|
# Get conversation-specific queues
|
||||||
|
back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id)
|
||||||
|
|
||||||
await web_chat_queue.put(
|
# append user message
|
||||||
(
|
|
||||||
username,
|
|
||||||
conversation_id,
|
|
||||||
{
|
|
||||||
"message": message,
|
|
||||||
"image_url": image_url, # list
|
|
||||||
"audio_url": audio_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 持久化
|
|
||||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||||
try:
|
try:
|
||||||
history = json.loads(conversation.history)
|
history = json.loads(conversation.history)
|
||||||
@@ -164,30 +151,12 @@ class ChatRoute(Route):
|
|||||||
username, conversation_id, history=json.dumps(history)
|
username, conversation_id, history=json.dumps(history)
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response().ok().__dict__
|
|
||||||
|
|
||||||
async def listener(self):
|
|
||||||
"""一直保持长连接"""
|
|
||||||
|
|
||||||
username = g.get("username", "guest")
|
|
||||||
|
|
||||||
if username in self.curr_chat_sse:
|
|
||||||
return Response().error("Already connected").__dict__
|
|
||||||
|
|
||||||
self.curr_chat_sse[username] = None
|
|
||||||
|
|
||||||
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
|
|
||||||
|
|
||||||
async def stream():
|
async def stream():
|
||||||
try:
|
try:
|
||||||
yield f"data: {heartbeat}\n\n" # 心跳包
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(back_queue.get(), timeout=10)
|
||||||
web_chat_back_queue.get(), timeout=10
|
|
||||||
) # 设置超时时间为5秒
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
yield f"data: {heartbeat}\n\n" # 心跳包
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
@@ -197,19 +166,13 @@ class ChatRoute(Route):
|
|||||||
type = result.get("type")
|
type = result.get("type")
|
||||||
cid = result.get("cid")
|
cid = result.get("cid")
|
||||||
streaming = result.get("streaming", False)
|
streaming = result.get("streaming", False)
|
||||||
if cid != self.curr_user_cid.get(username):
|
|
||||||
# 丢弃
|
|
||||||
continue
|
|
||||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
if streaming and type != "end":
|
if type == "end":
|
||||||
continue
|
break
|
||||||
|
elif (streaming and type == "complete") or not streaming:
|
||||||
if type == "update_title":
|
# append bot message
|
||||||
continue
|
|
||||||
|
|
||||||
if result_text:
|
|
||||||
conversation = self.db.get_conversation_by_user_id(
|
conversation = self.db.get_conversation_by_user_id(
|
||||||
username, cid
|
username, cid
|
||||||
)
|
)
|
||||||
@@ -222,11 +185,27 @@ class ChatRoute(Route):
|
|||||||
self.db.update_conversation(
|
self.db.update_conversation(
|
||||||
username, cid, history=json.dumps(history)
|
username, cid, history=json.dumps(history)
|
||||||
)
|
)
|
||||||
|
|
||||||
except BaseException as _:
|
except BaseException as _:
|
||||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||||
self.curr_chat_sse.pop(username)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Put message to conversation-specific queue
|
||||||
|
chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||||
|
await chat_queue.put(
|
||||||
|
(
|
||||||
|
username,
|
||||||
|
conversation_id,
|
||||||
|
{
|
||||||
|
"message": message,
|
||||||
|
"image_url": image_url, # list
|
||||||
|
"audio_url": audio_url,
|
||||||
|
"selected_provider": selected_provider,
|
||||||
|
"selected_model": selected_model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
response = await make_response(
|
response = await make_response(
|
||||||
stream(),
|
stream(),
|
||||||
{
|
{
|
||||||
@@ -236,7 +215,6 @@ class ChatRoute(Route):
|
|||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response.timeout = None
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def delete_conversation(self):
|
async def delete_conversation(self):
|
||||||
@@ -245,6 +223,8 @@ class ChatRoute(Route):
|
|||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
return Response().error("Missing key: conversation_id").__dict__
|
return Response().error("Missing key: conversation_id").__dict__
|
||||||
|
|
||||||
|
# Clean up queues when deleting conversation
|
||||||
|
webchat_queue_mgr.remove_queues(conversation_id)
|
||||||
self.db.delete_conversation(username, conversation_id)
|
self.db.delete_conversation(username, conversation_id)
|
||||||
return Response().ok().__dict__
|
return Response().ok().__dict__
|
||||||
|
|
||||||
@@ -279,6 +259,4 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||||
|
|
||||||
self.curr_user_cid[username] = conversation_id
|
|
||||||
|
|
||||||
return Response().ok(data=conversation).__dict__
|
return Response().ok(data=conversation).__dict__
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user