Compare commits
157 Commits
v3.5.19
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6b6eef8c4 | ||
|
|
50cf263076 | ||
|
|
2554548088 | ||
|
|
aa4a2d10e2 | ||
|
|
02a9769b35 | ||
|
|
7640f11bfc | ||
|
|
9fa44dbcfa | ||
|
|
2cae941bae | ||
|
|
bc0784f41d | ||
|
|
c57d75e01a | ||
|
|
73edeae013 | ||
|
|
7d46314dc8 | ||
|
|
d5a53a89eb | ||
|
|
a85bc510dd | ||
|
|
2beea7d218 | ||
|
|
a93cd3dd5f | ||
|
|
db4d02c2e2 | ||
|
|
fd7811402b | ||
|
|
eb0325e627 | ||
|
|
8b4b04ec09 | ||
|
|
9f32c9280f | ||
|
|
4fcd09cfa8 | ||
|
|
7a8d65d37d | ||
|
|
23129a9ba2 | ||
|
|
7f791e730b | ||
|
|
f7e296b349 | ||
|
|
712d4acaaa | ||
|
|
74a5c01f21 | ||
|
|
3ba8724d77 | ||
|
|
6313a7d8a9 | ||
|
|
432a3f520c | ||
|
|
191b3e42d4 | ||
|
|
a27f05fcb4 | ||
|
|
2f33e0b873 | ||
|
|
f0359467f1 | ||
|
|
d1db8cf2c8 | ||
|
|
b1985ed2ce | ||
|
|
140ddc70e6 | ||
|
|
d7fd616470 | ||
|
|
3ccbef141e | ||
|
|
e92fbb0443 | ||
|
|
bd270aed68 | ||
|
|
28d7864393 | ||
|
|
b5d8173ee3 | ||
|
|
17d62a9af7 | ||
|
|
d89fb863ed | ||
|
|
a21ad77820 | ||
|
|
f86c8e8cab | ||
|
|
cb12cbdd3d | ||
|
|
6661fa996c | ||
|
|
c19bca798b | ||
|
|
8f98b411db | ||
|
|
a8aa03847e | ||
|
|
1bfd747cc6 | ||
|
|
ae06d945a7 | ||
|
|
9f41d5f34d | ||
|
|
ef61c52908 | ||
|
|
d8842ef274 | ||
|
|
c88fdaf353 | ||
|
|
af295da871 | ||
|
|
083235a2fe | ||
|
|
2a3a5f7eb2 | ||
|
|
77c48f280f | ||
|
|
0ee1eb2f9f | ||
|
|
c2b20365bb | ||
|
|
cfdc7e4452 | ||
|
|
2363f61aa9 | ||
|
|
557ac6f9fa | ||
|
|
a49b871cf9 | ||
|
|
a0d6b3efba | ||
|
|
6cabf07bc0 | ||
|
|
a15444ee8c | ||
|
|
ceb5f5669e | ||
|
|
25b75e05e4 | ||
|
|
4d214bb5c1 | ||
|
|
7cbaed8c6c | ||
|
|
2915fdf665 | ||
|
|
a66c385b08 | ||
|
|
4dace7c5d8 | ||
|
|
8ebf087dbf | ||
|
|
2fa8bda5bb | ||
|
|
a5ae833945 | ||
|
|
d21d42b312 | ||
|
|
78575f0f0a | ||
|
|
8ccd292d16 | ||
|
|
2534f59398 | ||
|
|
5c60dbe2b1 | ||
|
|
c99ecde15f | ||
|
|
219f3403d9 | ||
|
|
00f417bad6 | ||
|
|
81649f053b | ||
|
|
e5bde50f2d | ||
|
|
0321e00b0d | ||
|
|
09528e3292 | ||
|
|
e7412a9cbf | ||
|
|
01efe5f869 | ||
|
|
28a178a55c | ||
|
|
88f130014c | ||
|
|
af258c590c | ||
|
|
b0eb5733be | ||
|
|
fe35bfba37 | ||
|
|
7cfbc4ab8f | ||
|
|
7a9d4f0abd | ||
|
|
6f6a5b565c | ||
|
|
e57deb873c | ||
|
|
0f692b1608 | ||
|
|
8c03e79f99 | ||
|
|
71290f0929 | ||
|
|
22364ef7de | ||
|
|
2cc1eb1abc | ||
|
|
90dbcbb4e2 | ||
|
|
66503d58be | ||
|
|
8e10f0ce2b | ||
|
|
f51f510f2e | ||
|
|
c44f085b47 | ||
|
|
a35f36eeaf | ||
|
|
14564c392a | ||
|
|
76e05ea749 | ||
|
|
ab599dceed | ||
|
|
4c37604445 | ||
|
|
bb74018d19 | ||
|
|
575289e5bc | ||
|
|
e89da2a7b4 | ||
|
|
d2f7e55bf5 | ||
|
|
9f31df7f3a | ||
|
|
e1bed60f1f | ||
|
|
edbb856023 | ||
|
|
98d3ab646f | ||
|
|
ab677ea100 | ||
|
|
28a87351f1 | ||
|
|
dcd7dcbbdf | ||
|
|
1538759ba7 | ||
|
|
ec5d71d0e1 | ||
|
|
d121d08d05 | ||
|
|
be08f4a558 | ||
|
|
4df8606ab6 | ||
|
|
71442d26ec | ||
|
|
4f5528869c | ||
|
|
f16feff17b | ||
|
|
d8aae538cd | ||
|
|
31670e75e5 | ||
|
|
ed6011a2be | ||
|
|
cdded38ade | ||
|
|
f536f24833 | ||
|
|
646b18d910 | ||
|
|
e24225c828 | ||
|
|
50a296de20 | ||
|
|
c79e38e044 | ||
|
|
dae745d925 | ||
|
|
791db65526 | ||
|
|
02e2e617f5 | ||
|
|
bfc8024119 | ||
|
|
f26cf6ed6f | ||
|
|
f2be55bd8e | ||
|
|
d241dd17ca | ||
|
|
cecafdfe6c | ||
|
|
6fecfd1a0e |
31
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md
vendored
31
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md
vendored
@@ -1,31 +0,0 @@
|
||||
---
|
||||
name: '🥳 发布插件'
|
||||
title: "[Plugin] 插件名"
|
||||
about: 提交插件到插件市场
|
||||
labels: [ "plugin-publish" ]
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
欢迎发布插件到插件市场!
|
||||
|
||||
## 插件基本信息
|
||||
|
||||
请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "插件名",
|
||||
"desc": "插件介绍",
|
||||
"author": "作者名",
|
||||
"repo": "插件仓库链接",
|
||||
"tags": [],
|
||||
"social_link": ""
|
||||
}
|
||||
```
|
||||
|
||||
## 检查
|
||||
|
||||
- [ ] 我的插件经过完整的测试
|
||||
- [ ] 我的插件不包含恶意代码
|
||||
- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
56
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
Normal file
56
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: 🥳 发布插件
|
||||
description: 提交插件到插件市场
|
||||
title: "[Plugin] 插件名"
|
||||
labels: ["plugin-publish"]
|
||||
assignees: []
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
欢迎发布插件到插件市场!
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## 插件基本信息
|
||||
|
||||
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
|
||||
|
||||
- type: textarea
|
||||
id: plugin-info
|
||||
attributes:
|
||||
label: 插件信息
|
||||
description: 请在下方代码块中填写您的插件信息,确保反引号包裹了JSON
|
||||
value: |
|
||||
```json
|
||||
{
|
||||
"name": "插件名",
|
||||
"desc": "插件介绍",
|
||||
"author": "作者名",
|
||||
"repo": "插件仓库链接",
|
||||
"tags": [],
|
||||
"social_link": ""
|
||||
}
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## 检查
|
||||
|
||||
- type: checkboxes
|
||||
id: checks
|
||||
attributes:
|
||||
label: 插件检查清单
|
||||
description: 请确认以下所有项目
|
||||
options:
|
||||
- label: 我的插件经过完整的测试
|
||||
required: true
|
||||
- label: 我的插件不包含恶意代码
|
||||
required: true
|
||||
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
63
.github/copilot-instructions.md
vendored
Normal file
63
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
# AstrBot Development Instructions
|
||||
|
||||
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
|
||||
|
||||
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
|
||||
|
||||
## Working Effectively
|
||||
|
||||
### Bootstrap and Install Dependencies
|
||||
- **Python 3.10+ required** - Check `.python-version` file
|
||||
- Install UV package manager: `pip install uv`
|
||||
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
|
||||
- Create required directories: `mkdir -p data/plugins data/config data/temp`
|
||||
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
- Navigate to dashboard: `cd dashboard`
|
||||
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
|
||||
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
|
||||
- Dashboard creates optimized production build in `dashboard/dist/`
|
||||
|
||||
### Testing
|
||||
- Do not generate test files for now.
|
||||
|
||||
### Code Quality and Linting
|
||||
- Install ruff linter: `uv add --dev ruff`
|
||||
- Check code style: `uv run ruff check .` -- takes <1 second
|
||||
- Check formatting: `uv run ruff format --check .` -- takes <1 second
|
||||
- Fix formatting: `uv run ruff format .`
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
### Common Issues and Workarounds
|
||||
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
|
||||
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
|
||||
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
|
||||
|
||||
## CI/CD Integration
|
||||
- GitHub Actions workflows in `.github/workflows/`
|
||||
- Docker builds supported via `Dockerfile`
|
||||
- Pre-commit hooks enforce ruff formatting and linting
|
||||
|
||||
## Docker Support
|
||||
- Primary deployment method: `docker run soulter/astrbot:latest`
|
||||
- Compose file available: `compose.yml`
|
||||
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
|
||||
- Volume mount required: `./data:/AstrBot/data`
|
||||
|
||||
## Multi-language Support
|
||||
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
|
||||
- UI supports internationalization
|
||||
- Default language is Chinese
|
||||
|
||||
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
|
||||
4
.github/workflows/auto_release.yml
vendored
4
.github/workflows/auto_release.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
|
||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
|
||||
20
.github/workflows/coverage_test.yml
vendored
20
.github/workflows/coverage_test.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
@@ -8,6 +8,7 @@ on:
|
||||
- 'README.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -16,7 +17,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -26,20 +27,19 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
pip install --editable .
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir data/temp
|
||||
mkdir -p data/plugins
|
||||
mkdir -p data/config
|
||||
mkdir -p data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
||||
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
2
.github/workflows/dashboard_ci.yml
vendored
2
.github/workflows/dashboard_ci.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: npm install, build
|
||||
run: |
|
||||
|
||||
2
.github/workflows/docker-image.yml
vendored
2
.github/workflows/docker-image.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Pull The Codes
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
|
||||
|
||||
96
README.md
96
README.md
@@ -27,57 +27,50 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
>
|
||||
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
<details><summary>1. AstrBot 现已自带知识库能力</summary>
|
||||
|
||||
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
> [!NOTE]
|
||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
> [!TIP]
|
||||
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
|
||||
## ✨ 使用方式
|
||||
|
||||
#### Docker 部署
|
||||
|
||||
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||
|
||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||
|
||||
#### 1Panel 部署
|
||||
|
||||
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||
|
||||
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
|
||||
|
||||
#### 在 雨云 上部署
|
||||
|
||||
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
@@ -101,31 +94,14 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者,直接通过 uvx 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
mkdir astrbot && cd astrbot
|
||||
uvx astrbot init
|
||||
# uvx astrbot run
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### 在 雨云 上部署
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| 微信个人号 | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
@@ -144,7 +120,7 @@ uvx astrbot init
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
@@ -152,6 +128,8 @@ uvx astrbot init
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
@@ -244,11 +222,5 @@ _✨ WebUI ✨_
|
||||

|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||
3. Please ensure compliance with local laws and regulations when using this project.
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
15
README_ja.md
15
README_ja.md
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
## ✨ 主な機能
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
@@ -35,7 +35,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
|
||||
> [!TIP]
|
||||
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
>
|
||||
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
|
||||
|
||||
## ✨ 使用方法
|
||||
@@ -136,11 +136,11 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> [!TIP]
|
||||
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
## 免責事項
|
||||
|
||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
||||
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
|
||||
<!-- ## ✨ ATRI [ベータテスト]
|
||||
|
||||
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
0
astrbot.lock
Normal file
0
astrbot.lock
Normal file
@@ -1 +1 @@
|
||||
__version__ = "3.5.8"
|
||||
__version__ = "3.5.23"
|
||||
|
||||
@@ -139,6 +139,14 @@ def conf():
|
||||
- dashboard.password: Dashboard 密码
|
||||
|
||||
- callback_api_base: 回调接口基址
|
||||
|
||||
可用子命令:
|
||||
|
||||
- set: 设置配置项值
|
||||
|
||||
- get: 获取配置项值
|
||||
|
||||
- login-info: 显示 Web 管理面板登录信息
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -204,3 +212,44 @@ def get_config(key: str = None):
|
||||
click.echo(f" {key}: {value}")
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="login-info")
|
||||
def get_login_info():
|
||||
"""显示 Web 管理面板的登录信息
|
||||
|
||||
在 Docker 环境中使用示例:
|
||||
docker exec -e ASTRBOT_ROOT=/AstrBot astrbot-container astrbot conf login-info
|
||||
"""
|
||||
config = _load_config()
|
||||
|
||||
try:
|
||||
username = _get_nested_item(config, "dashboard.username")
|
||||
# 注意:我们不显示实际的MD5哈希密码,而是提示用户如何重置
|
||||
click.echo("🔐 Web 管理面板登录信息:")
|
||||
click.echo(f" 用户名: {username}")
|
||||
click.echo(" 密码: [已加密存储]")
|
||||
click.echo()
|
||||
click.echo("💡 如需重置密码,请使用以下命令:")
|
||||
click.echo(" astrbot conf set dashboard.password <新密码>")
|
||||
click.echo()
|
||||
click.echo("🌐 访问地址:")
|
||||
|
||||
# 尝试获取端口信息
|
||||
try:
|
||||
port = _get_nested_item(config, "dashboard.port")
|
||||
click.echo(f" http://localhost:{port}")
|
||||
click.echo(f" http://your-server-ip:{port}")
|
||||
except (KeyError, TypeError):
|
||||
click.echo(" http://localhost:6185 (默认端口)")
|
||||
click.echo(" http://your-server-ip:6185 (默认端口)")
|
||||
|
||||
click.echo()
|
||||
click.echo("📋 Docker 环境使用说明:")
|
||||
click.echo(" 如果在 Docker 中运行,请使用以下命令格式:")
|
||||
click.echo(" docker exec -e ASTRBOT_ROOT=/AstrBot <容器名> astrbot conf login-info")
|
||||
|
||||
except KeyError:
|
||||
click.echo("❌ 无法找到登录配置,请先运行 'astrbot init' 初始化")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"获取登录信息失败: {str(e)}")
|
||||
|
||||
@@ -16,7 +16,13 @@ def check_astrbot_root(path: str | Path) -> bool:
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""获取Astrbot根目录路径"""
|
||||
return Path.cwd()
|
||||
import os
|
||||
|
||||
# 使用与core应用相同的路径解析逻辑,优先使用ASTRBOT_ROOT环境变量
|
||||
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||
return Path(path)
|
||||
else:
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
|
||||
@@ -117,19 +117,24 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
if "desc" not in metadata and "description" in metadata:
|
||||
metadata["desc"] = metadata["description"]
|
||||
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
}
|
||||
)
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
@@ -139,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
online_plugins.append(
|
||||
{
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.19"
|
||||
VERSION = "3.5.24"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -65,6 +65,7 @@ DEFAULT_CONFIG = {
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": True,
|
||||
"max_agent_step": 30,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -157,15 +158,6 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"微信个人号(Gewechat)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
"微信个人号(WeChatPadPro)": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
@@ -318,8 +310,7 @@ CONFIG_METADATA_2 = {
|
||||
"id": {
|
||||
"description": "机器人名称",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
|
||||
"hint": "机器人名称",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -370,7 +361,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"discord_token": {
|
||||
"description": "Discord Bot Token",
|
||||
@@ -486,13 +476,11 @@ CONFIG_METADATA_2 = {
|
||||
"regex": {
|
||||
"description": "正则表达式",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||
},
|
||||
"content_cleanup_rule": {
|
||||
"description": "过滤分段后的内容",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
|
||||
},
|
||||
},
|
||||
@@ -515,7 +503,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
@@ -545,7 +532,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
},
|
||||
},
|
||||
@@ -605,18 +591,19 @@ CONFIG_METADATA_2 = {
|
||||
"config_template": {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
"provider": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"hint": "也兼容所有与OpenAI API兼容的服务。",
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -624,24 +611,23 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "grok-2-latest",
|
||||
},
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
},
|
||||
"Anthropic": {
|
||||
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
|
||||
"id": "claude",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -651,21 +637,23 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
},
|
||||
"Ollama": {
|
||||
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
|
||||
"id": "ollama_default",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -677,6 +665,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -685,10 +674,12 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
},
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -697,6 +688,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
@@ -714,30 +706,29 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"provider": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek-chat",
|
||||
},
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
"provider": "302ai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4.1-mini",
|
||||
},
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
},
|
||||
"硅基流动": {
|
||||
"id": "siliconflow",
|
||||
"provider": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -746,10 +737,12 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"id": "ppio",
|
||||
"provider": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -758,22 +751,36 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
},
|
||||
"优云智算": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
},
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {
|
||||
"model": "moonshot-v1-8k",
|
||||
},
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
},
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -786,6 +793,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
"provider": "dify",
|
||||
"type": "dify",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -796,9 +804,11 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -813,8 +823,20 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"provider": "fastgpt",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
@@ -824,6 +846,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Whisper(API)": {
|
||||
"id": "whisper",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_api",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
@@ -832,16 +855,18 @@ CONFIG_METADATA_2 = {
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"Whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"id": "whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(本地加载)": {
|
||||
"sensevoice_hint": "(不用修改我)",
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "sensevoice",
|
||||
@@ -851,6 +876,7 @@ CONFIG_METADATA_2 = {
|
||||
"OpenAI TTS(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"provider": "openai",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
@@ -860,8 +886,9 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge TTS": {
|
||||
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"provider": "microsoft",
|
||||
"type": "edge_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
@@ -871,6 +898,7 @@ CONFIG_METADATA_2 = {
|
||||
"GSV TTS(本地加载)": {
|
||||
"id": "gsv_tts",
|
||||
"enable": False,
|
||||
"provider": "gpt_sovits",
|
||||
"type": "gsv_tts_selfhost",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:9880",
|
||||
@@ -902,6 +930,7 @@ CONFIG_METADATA_2 = {
|
||||
"GSVI TTS(API)": {
|
||||
"id": "gsvi_tts",
|
||||
"type": "gsvi_tts_api",
|
||||
"provider": "gpt_sovits_inference",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:5000",
|
||||
"character": "",
|
||||
@@ -911,6 +940,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"FishAudio TTS(API)": {
|
||||
"id": "fishaudio_tts",
|
||||
"provider": "fishaudio",
|
||||
"type": "fishaudio_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
@@ -921,6 +951,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"id": "dashscope_tts",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
@@ -932,6 +963,7 @@ CONFIG_METADATA_2 = {
|
||||
"Azure TTS": {
|
||||
"id": "azure_tts",
|
||||
"type": "azure_tts",
|
||||
"provider": "azure",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": True,
|
||||
"azure_tts_voice": "zh-CN-YunxiaNeural",
|
||||
@@ -945,6 +977,7 @@ CONFIG_METADATA_2 = {
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
"type": "minimax_tts_api",
|
||||
"provider": "minimax",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
@@ -966,6 +999,7 @@ CONFIG_METADATA_2 = {
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
"type": "volcengine_tts",
|
||||
"provider": "volcengine",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
@@ -979,6 +1013,7 @@ CONFIG_METADATA_2 = {
|
||||
"Gemini TTS": {
|
||||
"id": "gemini_tts",
|
||||
"type": "gemini_tts",
|
||||
"provider": "google",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"gemini_tts_api_key": "",
|
||||
@@ -991,17 +1026,19 @@ CONFIG_METADATA_2 = {
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
"provider": "openai",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1536,
|
||||
"embedding_dimensions": 1024,
|
||||
"timeout": 20,
|
||||
},
|
||||
"Gemini Embedding": {
|
||||
"id": "gemini_embedding",
|
||||
"type": "gemini_embedding",
|
||||
"provider": "google",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
@@ -1012,17 +1049,19 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
},
|
||||
"gpt_weights_path": {
|
||||
"description": "GPT模型文件路径",
|
||||
"type": "string",
|
||||
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"sovits_weights_path": {
|
||||
"description": "SoVITS模型文件路径",
|
||||
"type": "string",
|
||||
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_default_parms": {
|
||||
"description": "GPT_SoVITS默认参数",
|
||||
@@ -1033,13 +1072,11 @@ CONFIG_METADATA_2 = {
|
||||
"description": "参考音频文件路径",
|
||||
"type": "string",
|
||||
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_prompt_text": {
|
||||
"description": "参考音频文本",
|
||||
"type": "string",
|
||||
"hint": "必填!请填写参考音频讲述的文本",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gsv_prompt_lang": {
|
||||
"description": "参考音频文本语言",
|
||||
@@ -1266,19 +1303,16 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_native_coderunner": {
|
||||
"description": "启用原生代码执行器",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_url_context": {
|
||||
"description": "启用URL上下文功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
@@ -1462,7 +1496,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "部署SenseVoice",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"is_emotion": {
|
||||
"description": "情绪识别",
|
||||
@@ -1477,18 +1510,10 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {
|
||||
"description": "工作流固定输入变量",
|
||||
"type": "object",
|
||||
"obvious_hint": True,
|
||||
"items": {},
|
||||
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
||||
"invisible": True,
|
||||
},
|
||||
# "fastgpt_app_type": {
|
||||
# "description": "应用类型",
|
||||
# "type": "string",
|
||||
# "hint": "FastGPT 应用的应用类型。",
|
||||
# "options": ["agent", "workflow", "plugin"],
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
"dashscope_app_type": {
|
||||
"description": "应用类型",
|
||||
"type": "string",
|
||||
@@ -1499,7 +1524,6 @@ CONFIG_METADATA_2 = {
|
||||
"dialog-workflow",
|
||||
"task-workflow",
|
||||
],
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"timeout": {
|
||||
"description": "超时时间",
|
||||
@@ -1509,26 +1533,22 @@ CONFIG_METADATA_2 = {
|
||||
"openai-tts-voice": {
|
||||
"description": "voice",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||
},
|
||||
"fishaudio-tts-character": {
|
||||
"description": "character",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
||||
},
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
|
||||
"hint": "模型提供商名字。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商种类",
|
||||
@@ -1543,53 +1563,27 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用",
|
||||
"type": "bool",
|
||||
"hint": "是否启用该模型。未启用的模型将不会被使用。",
|
||||
"hint": "是否启用。",
|
||||
},
|
||||
"key": {
|
||||
"description": "API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。",
|
||||
"hint": "提供商 API Key。",
|
||||
},
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
"description": "基座模型路径",
|
||||
"type": "string",
|
||||
"hint": "基座模型路径。",
|
||||
},
|
||||
"adapter_model_path": {
|
||||
"description": "Adapter 模型路径",
|
||||
"type": "string",
|
||||
"hint": "Adapter 模型路径。如 Lora",
|
||||
},
|
||||
"llmtuner_template": {
|
||||
"description": "template",
|
||||
"type": "string",
|
||||
"hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。",
|
||||
},
|
||||
"finetuning_type": {
|
||||
"description": "微调类型",
|
||||
"type": "string",
|
||||
"hint": "微调类型。如 `lora`",
|
||||
},
|
||||
"quantization_bit": {
|
||||
"description": "量化位数",
|
||||
"type": "int",
|
||||
"hint": "量化位数。如 4",
|
||||
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "文本生成模型",
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {
|
||||
"description": "模型名称",
|
||||
"type": "string",
|
||||
"hint": "大语言模型的名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat 等。",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "模型最大输出长度(tokens)",
|
||||
@@ -1636,7 +1630,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"separate_provider": {
|
||||
"description": "提供商会话隔离",
|
||||
@@ -1656,13 +1649,11 @@ CONFIG_METADATA_2 = {
|
||||
"web_search": {
|
||||
"description": "启用网页搜索",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
},
|
||||
"web_search_link": {
|
||||
"description": "网页搜索引用链接",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||
},
|
||||
"display_reasoning_text": {
|
||||
@@ -1673,13 +1664,11 @@ CONFIG_METADATA_2 = {
|
||||
"identifier": {
|
||||
"description": "启动识别群员",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"description": "启用日期时间系统提示",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||
},
|
||||
"default_personality": {
|
||||
@@ -1717,6 +1706,10 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
"max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
@@ -1736,7 +1729,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "人格名称",
|
||||
"type": "string",
|
||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"prompt": {
|
||||
"description": "设定(系统提示词)",
|
||||
@@ -1748,14 +1740,12 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1767,7 +1757,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用语音转文本(STT)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID",
|
||||
@@ -1784,7 +1773,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用文本转语音(TTS)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID",
|
||||
@@ -1795,7 +1783,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用语音和文字双输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"use_file_service": {
|
||||
"description": "使用文件服务提供 TTS 语音文件",
|
||||
@@ -1811,25 +1798,21 @@ CONFIG_METADATA_2 = {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious_hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "群聊图像转述(需模型支持)",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"description": "图像转述提供商 ID",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
@@ -1843,14 +1826,12 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用主动回复",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
|
||||
},
|
||||
"whitelist": {
|
||||
"description": "主动回复白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
|
||||
},
|
||||
"method": {
|
||||
@@ -1862,13 +1843,11 @@ CONFIG_METADATA_2 = {
|
||||
"possibility_reply": {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"obvious_hint": True,
|
||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
@@ -1884,7 +1863,6 @@ CONFIG_METADATA_2 = {
|
||||
"description": "机器人唤醒前缀",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
||||
},
|
||||
"t2i": {
|
||||
@@ -1911,13 +1889,11 @@ CONFIG_METADATA_2 = {
|
||||
"timezone": {
|
||||
"description": "时区",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||
},
|
||||
"callback_api_base": {
|
||||
"description": "对外可达的回调接口地址",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
||||
},
|
||||
"log_level": {
|
||||
@@ -1965,90 +1941,3 @@ DEFAULT_VALUE_MAP = {
|
||||
"list": [],
|
||||
"object": {},
|
||||
}
|
||||
|
||||
|
||||
# "project_atri": {
|
||||
# "description": "Project ATRI 配置",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# "long_term_memory": {
|
||||
# "description": "长期记忆",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# "summary_threshold_cnt": {
|
||||
# "description": "摘要阈值",
|
||||
# "type": "int",
|
||||
# "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。",
|
||||
# },
|
||||
# "embedding_provider_id": {
|
||||
# "description": "Embedding provider ID",
|
||||
# "type": "string",
|
||||
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "summarize_provider_id": {
|
||||
# "description": "Summary provider ID",
|
||||
# "type": "string",
|
||||
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
# "active_message": {
|
||||
# "description": "主动消息",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# },
|
||||
# },
|
||||
# "vision": {
|
||||
# "description": "视觉理解",
|
||||
# "type": "object",
|
||||
# "items": {
|
||||
# "enable": {"description": "启用", "type": "bool"},
|
||||
# "provider_id_or_ofa_model_path": {
|
||||
# "description": "提供商 ID 或 OFA 模型路径",
|
||||
# "type": "string",
|
||||
# "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。",
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
# "split_response": {
|
||||
# "description": "是否分割回复",
|
||||
# "type": "bool",
|
||||
# "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。",
|
||||
# },
|
||||
# "persona": {
|
||||
# "description": "人格",
|
||||
# "type": "string",
|
||||
# "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "chat_provider_id": {
|
||||
# "description": "Chat provider ID",
|
||||
# "type": "string",
|
||||
# "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "chat_base_model_path": {
|
||||
# "description": "用于聊天的基座模型路径",
|
||||
# "type": "string",
|
||||
# "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "chat_adapter_model_path": {
|
||||
# "description": "用于聊天的 Lora 模型路径",
|
||||
# "type": "string",
|
||||
# "hint": "Lora 模型路径。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# "quantization_bit": {
|
||||
# "description": "量化位数",
|
||||
# "type": "int",
|
||||
# "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。",
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
# },
|
||||
# },
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
__all__ = ["FaissVecDB"]
|
||||
|
||||
@@ -2,6 +2,8 @@ import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from urllib.parse import urlparse, unquote
|
||||
import platform
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
@@ -15,7 +17,9 @@ class FileTokenService:
|
||||
async def _cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
||||
expired_tokens = [
|
||||
token for token, (_, expire) in self.staged_files.items() if expire < now
|
||||
]
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
@@ -32,15 +36,35 @@ class FileTokenService:
|
||||
Raises:
|
||||
FileNotFoundError: 当路径不存在时抛出
|
||||
"""
|
||||
|
||||
# 处理 file:///
|
||||
try:
|
||||
parsed_uri = urlparse(file_path)
|
||||
if parsed_uri.scheme == "file":
|
||||
local_path = unquote(parsed_uri.path)
|
||||
if platform.system() == "Windows" and local_path.startswith("/"):
|
||||
local_path = local_path[1:]
|
||||
else:
|
||||
# 如果没有 file:/// 前缀,则认为是普通路径
|
||||
local_path = file_path
|
||||
except Exception:
|
||||
# 解析失败时,按原路径处理
|
||||
local_path = file_path
|
||||
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
if not os.path.exists(local_path):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在: {local_path} (原始输入: {file_path})"
|
||||
)
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
||||
self.staged_files[file_token] = (file_path, expire_time)
|
||||
expire_time = time.time() + (
|
||||
timeout if timeout is not None else self.default_timeout
|
||||
)
|
||||
# 存储转换后的真实路径
|
||||
self.staged_files[file_token] = (local_path, expire_time)
|
||||
return file_token
|
||||
|
||||
async def handle_file(self, file_token: str) -> str:
|
||||
|
||||
@@ -96,8 +96,6 @@ class LogBroker:
|
||||
Queue: 订阅者的队列, 可用于接收日志消息
|
||||
"""
|
||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||
for log in self.log_cache:
|
||||
q.put_nowait(log)
|
||||
self.subscribers.append(q)
|
||||
return q
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ class Plain(BaseMessageComponent):
|
||||
async def to_dict(self):
|
||||
return {"type": "text", "data": {"text": self.text}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
id: int
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
EventResultType,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
@@ -29,6 +31,7 @@ STAGES_ORDER = [
|
||||
__all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
|
||||
@@ -8,10 +8,11 @@ from enum import Enum, auto
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Agent 状态枚举"""
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
|
||||
@@ -2,29 +2,30 @@
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import copy
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from ..stage import Stage
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -72,6 +73,12 @@ class LLMRequestSubStage(Stage):
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
@@ -166,6 +173,9 @@ class LLMRequestSubStage(Stage):
|
||||
event=event,
|
||||
pipeline_ctx=self.ctx,
|
||||
)
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
|
||||
)
|
||||
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||
|
||||
async def requesting():
|
||||
@@ -184,7 +194,8 @@ class LLMRequestSubStage(Stage):
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
if resp.type == "tool_call":
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if self.streaming_response:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
@@ -220,7 +231,7 @@ class LLMRequestSubStage(Stage):
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -177,25 +178,26 @@ class RespondStage(Stage):
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
|
||||
@@ -3,11 +3,12 @@ import time
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
@@ -176,10 +177,12 @@ class ResultDecorateStage(Stage):
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
|
||||
@@ -73,7 +73,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||
if event.get_platform_name() == "webchat":
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
event.stop_event()
|
||||
@@ -1,13 +1,16 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -166,6 +169,11 @@ class WakingCheckStage(Stage):
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
event, activated_handlers
|
||||
)
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
||||
|
||||
|
||||
@@ -227,7 +227,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp, gewechat。
|
||||
Fallback仅支持 aiocqhttp。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
@@ -419,7 +419,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
适配情况:
|
||||
|
||||
- gewechat
|
||||
- aiocqhttp(OneBotv11)
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -58,10 +58,6 @@ class PlatformManager:
|
||||
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
@@ -58,50 +58,85 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
@classmethod
|
||||
async def _dispatch_send(
|
||||
cls,
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str,
|
||||
messages: list[dict],
|
||||
):
|
||||
if event:
|
||||
await bot.send(event=event, message=messages)
|
||||
elif is_group:
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
else:
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
|
||||
@classmethod
|
||||
async def send_message(
|
||||
cls,
|
||||
bot: CQHttp,
|
||||
message_chain: MessageChain,
|
||||
event: Event | None = None,
|
||||
is_group: bool = False,
|
||||
session_id: str = None,
|
||||
):
|
||||
"""发送消息"""
|
||||
|
||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||
send_one_by_one = any(
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
|
||||
)
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 合并转发消息
|
||||
|
||||
if isinstance(seg, Node):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
else:
|
||||
payload["user_id"] = self.get_sender_id()
|
||||
await self.bot.call_action(
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
elif isinstance(seg, File):
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
[d],
|
||||
)
|
||||
else:
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
await AiocqhttpMessageEvent._parse_onebot_json(
|
||||
MessageChain([seg])
|
||||
),
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if not send_one_by_one:
|
||||
ret = await cls._parse_onebot_json(message_chain)
|
||||
if not ret:
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, ret)
|
||||
return
|
||||
for seg in message_chain.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 合并转发消息
|
||||
if isinstance(seg, Node):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if is_group:
|
||||
payload["group_id"] = session_id
|
||||
await bot.call_action("send_group_forward_msg", **payload)
|
||||
else:
|
||||
payload["user_id"] = session_id
|
||||
await bot.call_action("send_private_forward_msg", **payload)
|
||||
elif isinstance(seg, File):
|
||||
d = await cls._from_segment_to_dict(seg)
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, [d])
|
||||
else:
|
||||
messages = await cls._parse_onebot_json(MessageChain([seg]))
|
||||
if not messages:
|
||||
continue
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, messages)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息"""
|
||||
event = self.message_obj.raw_message
|
||||
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
|
||||
is_group = False
|
||||
if self.get_group_id():
|
||||
is_group = True
|
||||
session_id = self.get_group_id()
|
||||
else:
|
||||
session_id = self.get_sender_id()
|
||||
await self.send_message(
|
||||
bot=self.bot,
|
||||
message_chain=message,
|
||||
event=event,
|
||||
is_group=is_group,
|
||||
session_id=session_id,
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
|
||||
@@ -83,19 +83,18 @@ class AiocqhttpAdapter(Platform):
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
||||
match session.message_type.value:
|
||||
case MessageType.GROUP_MESSAGE.value:
|
||||
if "_" in session.session_id:
|
||||
# 独立会话
|
||||
_, group_id = session.session_id.split("_")
|
||||
await self.bot.send_group_msg(group_id=group_id, message=ret)
|
||||
else:
|
||||
await self.bot.send_group_msg(
|
||||
group_id=session.session_id, message=ret
|
||||
)
|
||||
case MessageType.FRIEND_MESSAGE.value:
|
||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||
is_group = session.message_type == MessageType.GROUP_MESSAGE
|
||||
if is_group:
|
||||
session_id = session.session_id.split("_")[-1]
|
||||
else:
|
||||
session_id = session.session_id
|
||||
await AiocqhttpMessageEvent.send_message(
|
||||
bot=self.bot,
|
||||
message_chain=message_chain,
|
||||
event=None, # 这里不需要 event,因为是通过 session 发送的
|
||||
is_group=is_group,
|
||||
session_id=session_id,
|
||||
)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
@@ -273,8 +272,14 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
|
||||
reply_event_data["post_type"] = "message"
|
||||
new_event = Event.from_payload(reply_event_data)
|
||||
if not new_event:
|
||||
logger.error(
|
||||
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
|
||||
)
|
||||
continue
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
new_event, get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
@@ -307,7 +312,9 @@ class AiocqhttpAdapter(Platform):
|
||||
user_id=int(m["data"]["qq"]),
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
|
||||
nickname = at_info.get("nick", "") or at_info.get(
|
||||
"nickname", ""
|
||||
)
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
|
||||
@@ -57,6 +57,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"钉钉图片处理失败: {e}")
|
||||
logger.warning(f"跳过图片发送: {image_path}")
|
||||
continue
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
|
||||
await self.on_ready_once_callback()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
|
||||
)
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
|
||||
message_data = self._create_message_data(message)
|
||||
await self.on_message_received(message_data)
|
||||
|
||||
|
||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||
"""从交互中提取内容"""
|
||||
interaction_type = interaction.type
|
||||
|
||||
@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
|
||||
|
||||
class DiscordReference(BaseMessageComponent):
|
||||
"""Discord引用组件"""
|
||||
|
||||
type: str = "discord_reference"
|
||||
|
||||
def __init__(self, message_id: str, channel_id: str):
|
||||
self.message_id = message_id
|
||||
self.channel_id = channel_id
|
||||
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def to_discord_view(self) -> discord.ui.View:
|
||||
"""转换为Discord View对象"""
|
||||
view = discord.ui.View(timeout=self.timeout)
|
||||
|
||||
@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
try:
|
||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
||||
(
|
||||
content,
|
||||
files,
|
||||
view,
|
||||
embeds,
|
||||
reference_message_id,
|
||||
) = await self._parse_to_discord(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||
return
|
||||
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
files.append(
|
||||
discord.File(BytesIO(file_bytes),
|
||||
filename=i.name)
|
||||
discord.File(BytesIO(file_bytes), filename=i.name)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -1,812 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SimpleGewechatClient:
|
||||
"""针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
nickname: str,
|
||||
host: str,
|
||||
port: int,
|
||||
event_queue: asyncio.Queue,
|
||||
):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith("/"):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
|
||||
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_token>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
self.userrealnames = {}
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self.staged_files = {}
|
||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob["data"]
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {"X-GEWE-TOKEN": self.token}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
if "TypeName" in data:
|
||||
type_name = data["TypeName"]
|
||||
elif "type_name" in data:
|
||||
type_name = data["type_name"]
|
||||
else:
|
||||
raise Exception("无法识别的消息类型")
|
||||
|
||||
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||
if type_name == "ModContacts":
|
||||
logger.info("gewechat下发:ModContacts消息通知。")
|
||||
return
|
||||
if type_name == "DelContacts":
|
||||
logger.info("gewechat下发:DelContacts消息通知。")
|
||||
return
|
||||
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
d = None
|
||||
if "Data" in data:
|
||||
d = data["Data"]
|
||||
elif "data" in data:
|
||||
d = data["data"]
|
||||
|
||||
if not d:
|
||||
logger.warning(f"消息不含 data 字段: {data}")
|
||||
return
|
||||
|
||||
if "CreateTime" in d:
|
||||
# 得到系统 UTF+8 的 ts
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
ts = datetime.datetime.now(tz).timestamp()
|
||||
create_time = d["CreateTime"]
|
||||
if create_time < ts - 30:
|
||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
from_user_name = d["FromUserName"]["string"] # 消息来源
|
||||
d["to_wxid"] = from_user_name # 用于发信息
|
||||
|
||||
abm.message_id = str(d.get("MsgId"))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data["Wxid"] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
at_wxids = []
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
at_wxids = re.findall(
|
||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||
msg_source,
|
||||
)
|
||||
|
||||
abm.group_id = from_user_name
|
||||
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
):
|
||||
at_me = True
|
||||
if "在群聊中@了你" in d.get("PushContent", ""):
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||
# if user_id == abm.self_id:
|
||||
# logger.info("忽略自己发送的消息")
|
||||
# return None
|
||||
|
||||
abm.message = []
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
if abm.group_id:
|
||||
if (
|
||||
abm.group_id not in self.userrealnames
|
||||
or user_id not in self.userrealnames[abm.group_id]
|
||||
):
|
||||
# 获取群成员列表,并且缓存
|
||||
if abm.group_id not in self.userrealnames:
|
||||
self.userrealnames[abm.group_id] = {}
|
||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||
if member_list and "memberList" in member_list:
|
||||
for member in member_list["memberList"]:
|
||||
self.userrealnames[abm.group_id][member["wxid"]] = member[
|
||||
"nickName"
|
||||
]
|
||||
if user_id in self.userrealnames[abm.group_id]:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
try:
|
||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||
user_real_name = info["nickName"]
|
||||
except Exception as e:
|
||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||
user_real_name = user_id
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||
for wxid in at_wxids:
|
||||
# 群聊里 At 其他人的列表
|
||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||
abm.message.append(At(qq=wxid, name=_username))
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
|
||||
if user_id == "weixin":
|
||||
# 忽略微信团队消息
|
||||
return
|
||||
|
||||
# 不同消息类型
|
||||
match d["MsgType"]:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid, content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||
case 37: # 好友申请
|
||||
logger.info("消息类型(37):好友申请")
|
||||
case 42: # 名片
|
||||
logger.info("消息类型(42):名片")
|
||||
case 43: # 视频
|
||||
video = Video(file="", cover=content)
|
||||
abm.message.append(video)
|
||||
case 47: # emoji
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
emoji = data_parser.parse_emoji()
|
||||
abm.message.append(emoji)
|
||||
case 48: # 地理位置
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
segments = data_parser.parse_mutil_49()
|
||||
if segments:
|
||||
abm.message.extend(segments)
|
||||
for seg in segments:
|
||||
if isinstance(seg, Plain):
|
||||
abm.message_str += seg.text
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||
logger.info(
|
||||
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||
)
|
||||
|
||||
case _:
|
||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||
abm.raw_message = d
|
||||
|
||||
logger.debug(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def _callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get("testMsg", None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = None
|
||||
try:
|
||||
abm = await self._convert(data)
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。"
|
||||
)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _register_file(self, file_path: str) -> str:
|
||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||
|
||||
Args:
|
||||
file_path (str): 文件路径。
|
||||
Returns:
|
||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||
"""
|
||||
async with self.lock:
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
self.staged_files[file_token] = file_path
|
||||
return file_token
|
||||
|
||||
async def _handle_file(self, file_token):
|
||||
async with self.lock:
|
||||
if file_token not in self.staged_files:
|
||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||
return quart.abort(404)
|
||||
if not os.path.exists(self.staged_files[file_token]):
|
||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||
return quart.abort(404)
|
||||
file_path = self.staged_files[file_token]
|
||||
self.staged_files.pop(file_token, None)
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={"token": self.token, "callbackUrl": self.callback_url},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob["ret"] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(
|
||||
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
|
||||
)
|
||||
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host="0.0.0.0",
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
"""检查 APPID 对应的设备是否在线。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={"appId": appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def logout(self):
|
||||
"""登出 gewechat。"""
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={"appId": self.appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(
|
||||
self.base_url, self.download_base_url, self.token
|
||||
)
|
||||
|
||||
if self.appid:
|
||||
try:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态失败: {e}")
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob["ret"] != 200:
|
||||
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||
if "设备不存在" in error_msg:
|
||||
logger.error(
|
||||
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||
)
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
return await self.login()
|
||||
else:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob["data"]["qrData"]
|
||||
qr_uuid = json_blob["data"]["uuid"]
|
||||
appid = json_blob["data"]["appId"]
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(
|
||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({"uuid": qr_uuid, "appId": appid})
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
|
||||
ret = json_blob["ret"]
|
||||
msg = ""
|
||||
if json_blob["data"] and "msg" in json_blob["data"]:
|
||||
msg = json_blob["data"]["msg"]
|
||||
if ret == 500 and "安全验证码" in msg:
|
||||
logger.warning(
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
if "status" in json_blob["data"]:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||
"""
|
||||
|
||||
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||
"""获取群成员列表。
|
||||
|
||||
Args:
|
||||
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||
|
||||
Returns:
|
||||
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||
"""
|
||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
"""发送纯文本消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
if ats:
|
||||
payload["ats"] = ats
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
"""发送图片消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||
"""发送emoji消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"emojiMd5": emoji_md5,
|
||||
"emojiSize": emoji_size,
|
||||
}
|
||||
|
||||
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||
try:
|
||||
if emoji_md5 != "" and emoji_size != "":
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postEmoji",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(
|
||||
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||
)
|
||||
else:
|
||||
await self.post_image(to_wxid, cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
async def post_video(
|
||||
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"videoUrl": video_url,
|
||||
"thumbUrl": thumb_url,
|
||||
"videoDuration": video_duration,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送视频结果: {json_blob}")
|
||||
|
||||
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||
"""转发视频
|
||||
|
||||
Args:
|
||||
to_wxid (str): 发送给谁
|
||||
cnd_xml (str): 视频消息的cdn信息
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"xml": cnd_xml,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/forwardVideo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"转发视频结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
"""发送语音信息
|
||||
|
||||
Args:
|
||||
voice_url (str): 语音文件的网络链接
|
||||
voice_duration (int): 语音时长,毫秒
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration,
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||
|
||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||
"""发送文件
|
||||
|
||||
Args:
|
||||
to_wxid (string): 微信ID
|
||||
file_url (str): 文件的网络链接
|
||||
file_name (str): 文件名
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"fileUrl": file_url,
|
||||
"fileName": file_name,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送文件结果: {json_blob}")
|
||||
|
||||
async def add_friend(self, v3: str, v4: str, content: str):
|
||||
"""申请添加好友"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"scene": 3,
|
||||
"content": content,
|
||||
"v4": v4,
|
||||
"v3": v3,
|
||||
"option": 2,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/addContacts",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group_member(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def accept_group_invite(self, url: str):
|
||||
"""同意进群"""
|
||||
payload = {"appId": self.appid, "url": url}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/agreeJoinRoom",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def add_group_member_to_friend(
|
||||
self, group_id: str, to_wxid: str, content: str
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
"content": content,
|
||||
"memberWxid": to_wxid,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_user_or_group_info(self, *ids):
|
||||
"""
|
||||
获取用户或群组信息。
|
||||
|
||||
:param ids: 可变数量的 wxid 参数
|
||||
"""
|
||||
|
||||
wxids_str = list(ids)
|
||||
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/getDetailInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_contacts_list(self):
|
||||
"""
|
||||
获取通讯录列表
|
||||
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||
"""
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/fetchContactsList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||
return json_blob
|
||||
@@ -1,55 +0,0 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
class GeweDownloader:
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
"""返回一个可下载的 URL"""
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {"appId": appid, "xml": xml, "type": choice}
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadImage", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
if "fileUrl" in json_blob["data"]:
|
||||
return self.download_base_url + json_blob["data"]["fileUrl"]
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
|
||||
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||
"""下载emoji"""
|
||||
try:
|
||||
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||
|
||||
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadEmojiMd5", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
return json_blob
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download emoji: {e}")
|
||||
@@ -1,264 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
Record,
|
||||
At,
|
||||
File,
|
||||
Video,
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, "rb") as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
elif n_frames == 0:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(
|
||||
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
||||
):
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
# 检查@
|
||||
ats = []
|
||||
ats_names = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, At):
|
||||
ats.append(comp.qq)
|
||||
ats_names.append(comp.name)
|
||||
has_at = False
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text = comp.text
|
||||
payload = {
|
||||
"to_wxid": to_wxid,
|
||||
"content": text,
|
||||
}
|
||||
if not has_at and ats:
|
||||
ats = f"{','.join(ats)}"
|
||||
ats_names = f"@{' @'.join(ats_names)}"
|
||||
text = f"{ats_names} {text}"
|
||||
payload["content"] = text
|
||||
payload["ats"] = ats
|
||||
has_at = True
|
||||
await client.post_text(**payload)
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||
token = await client._register_file(img_path)
|
||||
img_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
if comp.cover != "":
|
||||
await client.forward_video(to_wxid, comp.cover)
|
||||
else:
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
raise ModuleNotFoundError(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
|
||||
video_token = await client._register_file(video_path)
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||
ff.options(command)
|
||||
thumb_token = await client._register_file(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
|
||||
# 创建 FFprobe 实例
|
||||
ffprobe = FFprobe(video_url)
|
||||
# 获取时长字符串
|
||||
duration_str = ffprobe.duration
|
||||
# 处理时长字符串
|
||||
video_duration = float(duration_str.replace(":", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
# 发送视频
|
||||
await client.post_video(
|
||||
to_wxid, video_callback_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时缩略图文件
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
token = await client._register_file(silk_path)
|
||||
record_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
file_path = comp.file
|
||||
file_name = comp.name
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
token = await client._register_file(file_path)
|
||||
file_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_name)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||
await super().send(message)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
# 确定有效的 group_id
|
||||
if group_id is None:
|
||||
group_id = self.get_group_id()
|
||||
|
||||
if not group_id:
|
||||
return None
|
||||
|
||||
res = await self.client.get_group(group_id)
|
||||
data: dict = res["data"]
|
||||
|
||||
if not data["chatroomId"]:
|
||||
return None
|
||||
|
||||
members = [
|
||||
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||
for member in data.get("memberList", [])
|
||||
]
|
||||
|
||||
return Group(
|
||||
group_id=data["chatroomId"],
|
||||
group_name=data.get("nickName"),
|
||||
group_avatar=data.get("smallHeadImgUrl"),
|
||||
group_owner=data.get("chatRoomOwner"),
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
@@ -1,103 +0,0 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||
self.client = None
|
||||
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config["base_url"],
|
||||
self.config["nickname"],
|
||||
self.config["host"],
|
||||
self.config["port"],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
session_id = session.session_id
|
||||
if "#" in session_id:
|
||||
# unique session
|
||||
to_wxid = session_id.split("#")[1]
|
||||
else:
|
||||
to_wxid = session_id
|
||||
|
||||
await GewechatPlatformEvent.send_with_client(
|
||||
message_chain, to_wxid, self.client
|
||||
)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="gewechat",
|
||||
description="基于 gewechat 的 Wechat 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
self.client.shutdown_event.set()
|
||||
try:
|
||||
await self.client.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
await self.client.start_polling()
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss["unique_session"]:
|
||||
message.session_id = message.sender.user_id + "#" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self) -> SimpleGewechatClient:
|
||||
return self.client
|
||||
@@ -1,110 +0,0 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Reply,
|
||||
Plain,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
def __init__(self, data, is_private_chat):
|
||||
self.data = data
|
||||
self.is_private_chat = is_private_chat
|
||||
|
||||
def _format_to_xml(self):
|
||||
return eT.fromstring(self.data)
|
||||
|
||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||
if appmsg_type is None:
|
||||
return
|
||||
|
||||
match appmsg_type.text:
|
||||
case "57":
|
||||
return self.parse_reply()
|
||||
|
||||
def parse_emoji(self) -> Emoji | None:
|
||||
try:
|
||||
emoji_element = self._format_to_xml().find(".//emoji")
|
||||
# 提取 md5 和 len 属性
|
||||
if emoji_element is not None:
|
||||
md5_value = emoji_element.get("md5")
|
||||
emoji_size = emoji_element.get("len")
|
||||
cdnurl = emoji_element.get("cdnurl")
|
||||
|
||||
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||
|
||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||
"""解析引用消息
|
||||
|
||||
Returns:
|
||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||
"""
|
||||
try:
|
||||
replied_id = -1
|
||||
replied_uid = 0
|
||||
replied_nickname = ""
|
||||
replied_content = "" # 被引用者说的内容
|
||||
content = "" # 引用者说的内容
|
||||
|
||||
root = self._format_to_xml()
|
||||
refermsg = root.find(".//refermsg")
|
||||
if refermsg is not None:
|
||||
# 被引用的信息
|
||||
svrid = refermsg.find("svrid")
|
||||
fromusr = refermsg.find("fromusr")
|
||||
displayname = refermsg.find("displayname")
|
||||
refermsg_content = refermsg.find("content")
|
||||
if svrid is not None:
|
||||
replied_id = svrid.text
|
||||
if fromusr is not None:
|
||||
replied_uid = fromusr.text
|
||||
if displayname is not None:
|
||||
replied_nickname = displayname.text
|
||||
if refermsg_content is not None:
|
||||
# 处理引用嵌套,包括嵌套公众号消息
|
||||
if refermsg_content.text.startswith(
|
||||
"<msg>"
|
||||
) or refermsg_content.text.startswith("<?xml"):
|
||||
try:
|
||||
logger.debug("gewechat: Reference message is nested")
|
||||
refer_root = eT.fromstring(refermsg_content.text)
|
||||
img = refer_root.find("img")
|
||||
if img is not None:
|
||||
replied_content = "[图片]"
|
||||
else:
|
||||
app_msg = refer_root.find("appmsg")
|
||||
refermsg_content_title = app_msg.find("title")
|
||||
logger.debug(
|
||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||
)
|
||||
replied_content = refermsg_content_title.text
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: nested failed, {e}")
|
||||
# 处理异常情况
|
||||
replied_content = refermsg_content.text
|
||||
else:
|
||||
replied_content = refermsg_content.text
|
||||
|
||||
# 提取引用者说的内容
|
||||
title = root.find(".//appmsg/title")
|
||||
if title is not None:
|
||||
content = title.text
|
||||
|
||||
reply_seg = Reply(
|
||||
id=replied_id,
|
||||
chain=[Plain(replied_content)],
|
||||
sender_id=replied_uid,
|
||||
sender_nickname=replied_nickname,
|
||||
message_str=replied_content,
|
||||
)
|
||||
plain_seg = Plain(content)
|
||||
return [reply_seg, plain_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||
@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
return base64_content
|
||||
else:
|
||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
||||
logger.error(
|
||||
f"Failed to download slack file: {resp.status} {await resp.text()}"
|
||||
)
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
|
||||
@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
return chunks
|
||||
|
||||
@classmethod
|
||||
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
|
||||
async def send_with_client(
|
||||
cls, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
|
||||
@@ -22,7 +22,11 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
if not message:
|
||||
await web_chat_back_queue.put(
|
||||
{"type": "end", "data": "", "streaming": False}
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
} # end means this request is finished
|
||||
)
|
||||
return ""
|
||||
|
||||
@@ -99,16 +103,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
"cid": cid,
|
||||
}
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
@@ -120,7 +114,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
@@ -134,7 +128,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"type": "complete", # complete means we return the final result
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
class WebChatQueueMgr:
|
||||
def __init__(self) -> None:
|
||||
self.queues = {}
|
||||
@@ -30,4 +31,5 @@ class WebChatQueueMgr:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
|
||||
webchat_queue_mgr = WebChatQueueMgr()
|
||||
|
||||
@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
|
||||
def _extract_auth_key(self, data):
|
||||
"""Helper method to extract auth_key from response data."""
|
||||
if isinstance(data, dict):
|
||||
auth_keys = data.get("authKeys") # 新接口
|
||||
auth_keys = data.get("authKeys") # 新接口
|
||||
if isinstance(auth_keys, list) and auth_keys:
|
||||
return auth_keys[0]
|
||||
elif isinstance(data, list) and data: # 旧接口
|
||||
elif isinstance(data, list) and data: # 旧接口
|
||||
return data[0]
|
||||
return None
|
||||
|
||||
@@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform):
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
|
||||
logger.error(
|
||||
f"生成授权码失败: {response.status}, {await response.text()}"
|
||||
)
|
||||
return
|
||||
|
||||
response_data = await response.json()
|
||||
@@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform):
|
||||
if self.auth_key:
|
||||
logger.info("成功获取授权码")
|
||||
else:
|
||||
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
|
||||
logger.error(
|
||||
f"生成授权码成功但未找到授权码: {response_data}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"生成授权码失败: {response_data}")
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
|
||||
@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
|
||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||
data = {
|
||||
"token": token,
|
||||
"cursor": cursor,
|
||||
"limit": limit,
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
return self._post("kf/sync_msg", data=data)
|
||||
|
||||
def get_service_state(self, open_kfid, external_userid):
|
||||
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
}
|
||||
return self._post("kf/service_state/get", data=data)
|
||||
|
||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||
def trans_service_state(
|
||||
self, open_kfid, external_userid, service_state, servicer_userid=""
|
||||
):
|
||||
"""
|
||||
变更会话状态
|
||||
|
||||
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
"""
|
||||
return self._get("kf/customer/get_upgrade_service_config")
|
||||
|
||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||
def upgrade_service(
|
||||
self, open_kfid, external_userid, service_type, member=None, groupchat=None
|
||||
):
|
||||
"""
|
||||
为客户升级为专员或客户群服务
|
||||
|
||||
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
|
||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||
return self._post("kf/get_corp_statistic", data=data)
|
||||
|
||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||
def get_servicer_statistic(
|
||||
self, start_time, end_time, open_kfid=None, servicer_userid=None
|
||||
):
|
||||
"""
|
||||
获取「客户数据统计」接待人员明细数据
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from optionaldict import optionaldict
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
|
||||
class WeChatKFMessage(BaseWeChatAPI):
|
||||
"""
|
||||
发送微信客服消息
|
||||
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
|
||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||
)
|
||||
|
||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||
def send_msgmenu(
|
||||
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "msgmenu",
|
||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||
"msgmenu": {
|
||||
"head_content": head_content,
|
||||
"list": menu_list,
|
||||
"tail_content": tail_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||
def send_location(
|
||||
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "location",
|
||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||
"msgmenu": {
|
||||
"name": name,
|
||||
"address": address,
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||
def send_miniprogram(
|
||||
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
|
||||
):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "miniprogram",
|
||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||
"msgmenu": {
|
||||
"appid": appid,
|
||||
"title": title,
|
||||
"thumb_media_id": thumb_media_id,
|
||||
"pagepath": pagepath,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.shield(future), 60
|
||||
) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
|
||||
@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
|
||||
@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
@@ -80,12 +146,10 @@ class FuncTool:
|
||||
if not self.mcp_client or not self.mcp_client.session:
|
||||
raise Exception(f"MCP client for {self.name} is not available")
|
||||
# 使用name属性而不是额外的mcp_tool_name
|
||||
if ":" in self.name:
|
||||
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||
actual_tool_name = self.name.split(":")[-1]
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
return await self.mcp_client.session.call_tool(self.name, args)
|
||||
actual_tool_name = (
|
||||
self.name.split(":")[-1] if ":" in self.name else self.name
|
||||
)
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
|
||||
@@ -100,6 +164,7 @@ class MCPClient:
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
@@ -112,17 +177,19 @@ class MCPClient:
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = mcp_server_config.copy()
|
||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||
cfg = cfg["mcpServers"][key_0]
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
@@ -130,11 +197,18 @@ class MCPClient:
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
@@ -148,11 +222,19 @@ class MCPClient:
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -172,7 +254,7 @@ class MCPClient:
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
@@ -180,19 +262,18 @@ class MCPClient:
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
|
||||
|
||||
class FuncCall:
|
||||
@@ -201,8 +282,6 @@ class FuncCall:
|
||||
"""内部加载的 func tools"""
|
||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_service_queue = asyncio.Queue()
|
||||
"""用于外部控制 MCP 服务的启停"""
|
||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
@@ -258,7 +337,7 @@ class FuncCall:
|
||||
return f
|
||||
return None
|
||||
|
||||
async def _init_mcp_clients(self) -> None:
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
{
|
||||
@@ -300,113 +379,64 @@ class FuncCall:
|
||||
)
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
async def mcp_service_selector(self):
|
||||
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||
|
||||
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||
|
||||
{"type": "init"} 初始化所有MCP客户端
|
||||
|
||||
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||
|
||||
{"type": "terminate"} 终止所有MCP客户端
|
||||
|
||||
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||
"""
|
||||
while True:
|
||||
data = await self.mcp_service_queue.get()
|
||||
if data["type"] == "init":
|
||||
if "name" in data:
|
||||
event = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(
|
||||
data["name"], data["cfg"], event
|
||||
)
|
||||
)
|
||||
self.mcp_client_event[data["name"]] = event
|
||||
else:
|
||||
await self._init_mcp_clients()
|
||||
elif data["type"] == "terminate":
|
||||
if "name" in data:
|
||||
# await self._terminate_mcp_client(data["name"])
|
||||
if data["name"] in self.mcp_client_event:
|
||||
self.mcp_client_event[data["name"]].set()
|
||||
self.mcp_client_event.pop(data["name"], None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (
|
||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
for name in self.mcp_client_dict.keys():
|
||||
# await self._terminate_mcp_client(name)
|
||||
# self.mcp_client_event[name].set()
|
||||
if name in self.mcp_client_event:
|
||||
self.mcp_client_event[name].set()
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self, name: str, cfg: dict, event: asyncio.Event
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
await self._terminate_mcp_client(name)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
finally:
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
"""初始化单个MCP客户端"""
|
||||
try:
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||
]
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||
]
|
||||
|
||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||
for tool in mcp_client.tools:
|
||||
func_tool = FuncTool(
|
||||
name=tool.name,
|
||||
parameters=tool.inputSchema,
|
||||
description=tool.description,
|
||||
origin="mcp",
|
||||
mcp_server_name=name,
|
||||
mcp_client=mcp_client,
|
||||
)
|
||||
self.func_list.append(func_tool)
|
||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||
for tool in mcp_client.tools:
|
||||
func_tool = FuncTool(
|
||||
name=tool.name,
|
||||
parameters=tool.inputSchema,
|
||||
description=tool.description,
|
||||
origin="mcp",
|
||||
mcp_server_name=name,
|
||||
mcp_client=mcp_client,
|
||||
)
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -414,9 +444,9 @@ class FuncCall:
|
||||
try:
|
||||
# 关闭MCP连接
|
||||
await self.mcp_client_dict[name].cleanup()
|
||||
del self.mcp_client_dict[name]
|
||||
self.mcp_client_dict.pop(name)
|
||||
except Exception as e:
|
||||
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
@@ -425,6 +455,103 @@ class FuncCall:
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
if "url" in config:
|
||||
success, error_msg = await _quick_test_mcp_connection(config)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
try:
|
||||
logger.debug(f"testing MCP server connection with config: {config}")
|
||||
await mcp_client.connect_to_server(config, "test")
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
finally:
|
||||
logger.debug("Cleaning up MCP client after testing connection.")
|
||||
await mcp_client.cleanup()
|
||||
return tool_names
|
||||
|
||||
async def enable_mcp_server(
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
Raises:
|
||||
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str | None = None, timeout: float = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server by its name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout (int): Timeout.
|
||||
"""
|
||||
if name:
|
||||
if name not in self.mcp_client_event:
|
||||
return
|
||||
client = self.mcp_client_dict.get(name)
|
||||
self.mcp_client_event[name].set()
|
||||
if not client:
|
||||
return
|
||||
client_running_event = client.running_event
|
||||
try:
|
||||
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if f.origin != "mcp" or f.mcp_server_name != name
|
||||
]
|
||||
else:
|
||||
running_events = [
|
||||
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||
]
|
||||
for key, event in self.mcp_client_event.items():
|
||||
event.set()
|
||||
# waiting for all clients to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.clear()
|
||||
self.mcp_client_dict.clear()
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
@@ -629,8 +756,3 @@ class FuncCall:
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
async def terminate(self):
|
||||
for name in self.mcp_client_dict.keys():
|
||||
await self._terminate_mcp_client(name)
|
||||
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
||||
|
||||
@@ -7,7 +7,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map: dict[str, Provider] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
@@ -169,10 +169,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
||||
)
|
||||
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
@@ -422,7 +419,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = None
|
||||
|
||||
if getattr(self.inst_map[provider_id], "terminate", None):
|
||||
await self.inst_map[provider_id].terminate() # type: ignore
|
||||
await self.inst_map[provider_id].terminate() # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
||||
@@ -432,6 +429,8 @@ class ProviderManager:
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
await provider_inst.terminate() # type: ignore
|
||||
# 清理 MCP Client 连接
|
||||
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
||||
await provider_inst.terminate() # type: ignore
|
||||
try:
|
||||
await self.llm_tools.disable_mcp_server()
|
||||
except Exception:
|
||||
logger.error("Error while disabling MCP servers", exc_info=True)
|
||||
|
||||
@@ -2,7 +2,8 @@ import abc
|
||||
from typing import List
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -22,6 +23,7 @@ class ProviderMeta:
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
provider_type: ProviderType
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -40,10 +42,14 @@ class AbstractProvider(abc.ABC):
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
"""获取 Provider 的元数据"""
|
||||
provider_type_name = self.provider_config["type"]
|
||||
meta_data = provider_cls_map.get(provider_type_name)
|
||||
provider_type = meta_data.provider_type if meta_data else None
|
||||
return ProviderMeta(
|
||||
id=self.provider_config["id"],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config["type"],
|
||||
type=provider_type_name,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ class ProviderDify(Provider):
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
|
||||
@@ -470,6 +470,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
raise
|
||||
continue
|
||||
|
||||
# Accumulate the complete response text for the final response
|
||||
accumulated_text = ""
|
||||
final_response = None
|
||||
|
||||
async for chunk in result:
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
@@ -481,23 +485,37 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
return
|
||||
|
||||
if chunk.text:
|
||||
accumulated_text += chunk.text
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
if not chunk.candidates[0].content.parts:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||
else:
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
# Process the final chunk for potential tool calls or other content
|
||||
if chunk.candidates[0].content.parts:
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
final_response.result_chain = self._process_content_parts(
|
||||
chunk, final_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
# Yield final complete response with accumulated text
|
||||
if not final_response:
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
|
||||
# Set the complete accumulated text in the final response
|
||||
if accumulated_text:
|
||||
final_response.result_chain = MessageChain(
|
||||
chain=[Comp.Plain(accumulated_text)]
|
||||
)
|
||||
elif not final_response.result_chain:
|
||||
# If no text was accumulated and no final response was set, provide empty space
|
||||
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||
|
||||
yield final_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
|
||||
@@ -99,6 +99,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for key in to_del:
|
||||
del payloads[key]
|
||||
|
||||
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
|
||||
model = payloads.get("model", "")
|
||||
if "qwen3" in model.lower():
|
||||
extra_body["enable_thinking"] = False
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
@@ -176,7 +181,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
if choice.message.content:
|
||||
if choice.message.content is not None:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
@@ -187,6 +192,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
func_name_ls = []
|
||||
tool_call_ids = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
# workaround for #1454
|
||||
@@ -207,7 +215,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
|
||||
)
|
||||
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
if llm_response.completion_text is None and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
@@ -482,13 +490,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"""
|
||||
new_contexts = []
|
||||
|
||||
flag = False
|
||||
for context in contexts:
|
||||
if flag:
|
||||
flag = False # 删除 image 后,下一条(LLM 响应)也要删除
|
||||
continue
|
||||
if "content" in context and isinstance(context["content"], list):
|
||||
flag = True
|
||||
# continue
|
||||
new_content = []
|
||||
for item in context["content"]:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .star import StarMetadata
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
from .star_manager import PluginManager
|
||||
from .context import Context
|
||||
from astrbot.core.provider import Provider
|
||||
@@ -10,25 +10,48 @@ from astrbot.core.star.star_tools import StarTools
|
||||
class Star(CommandParserMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
def __init__(self, context: Context, config: dict | None = None):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not star_map.get(cls.__module__):
|
||||
metadata = StarMetadata(
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__,
|
||||
)
|
||||
star_map[cls.__module__] = metadata
|
||||
star_registry.append(metadata)
|
||||
else:
|
||||
star_map[cls.__module__].star_cls_type = cls
|
||||
star_map[cls.__module__].module_path = cls.__module__
|
||||
|
||||
@staticmethod
|
||||
async def text_to_image(text: str, return_url=True) -> str:
|
||||
"""将文本转换为图片"""
|
||||
return await html_renderer.render_t2i(text, return_url=return_url)
|
||||
|
||||
@staticmethod
|
||||
async def html_render(
|
||||
self, tmpl: str, data: dict, return_url=True, options: dict = None
|
||||
tmpl: str, data: dict, return_url=True, options: dict = None
|
||||
) -> str:
|
||||
"""渲染 HTML"""
|
||||
return await html_renderer.render_custom_template(
|
||||
tmpl, data, return_url=return_url, options=options
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""当插件被激活时会调用这个方法"""
|
||||
pass
|
||||
|
||||
async def terminate(self):
|
||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||
|
||||
@@ -2,7 +2,12 @@ from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.provider import (
|
||||
Provider,
|
||||
TTSProvider,
|
||||
STTProvider,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -141,6 +146,10 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
)
|
||||
raise ValueError(
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||
+ tree
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||
)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
|
||||
@@ -8,22 +8,45 @@ from typing import Union
|
||||
class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
VCHAT = enum.auto()
|
||||
GEWECHAT = enum.auto()
|
||||
TELEGRAM = enum.auto()
|
||||
WECOM = enum.auto()
|
||||
LARK = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT | TELEGRAM | WECOM | LARK
|
||||
WECHATPADPRO = enum.auto()
|
||||
DINGTALK = enum.auto()
|
||||
DISCORD = enum.auto()
|
||||
SLACK = enum.auto()
|
||||
KOOK = enum.auto()
|
||||
VOCECHAT = enum.auto()
|
||||
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||
ALL = (
|
||||
AIOCQHTTP
|
||||
| QQOFFICIAL
|
||||
| TELEGRAM
|
||||
| WECOM
|
||||
| LARK
|
||||
| WECHATPADPRO
|
||||
| DINGTALK
|
||||
| DISCORD
|
||||
| SLACK
|
||||
| KOOK
|
||||
| VOCECHAT
|
||||
| WEIXIN_OFFICIAL_ACCOUNT
|
||||
)
|
||||
|
||||
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"vchat": PlatformAdapterType.VCHAT,
|
||||
"gewechat": PlatformAdapterType.GEWECHAT,
|
||||
"telegram": PlatformAdapterType.TELEGRAM,
|
||||
"wecom": PlatformAdapterType.WECOM,
|
||||
"lark": PlatformAdapterType.LARK,
|
||||
"dingtalk": PlatformAdapterType.DINGTALK,
|
||||
"discord": PlatformAdapterType.DISCORD,
|
||||
"slack": PlatformAdapterType.SLACK,
|
||||
"kook": PlatformAdapterType.KOOK,
|
||||
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
||||
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from ..star import star_registry, StarMetadata, star_map
|
||||
import warnings
|
||||
|
||||
from astrbot.core.star import StarMetadata, star_map
|
||||
|
||||
_warned_register_star = False
|
||||
|
||||
|
||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
||||
"""注册一个插件(Star)。
|
||||
|
||||
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||
在 v3.5.19 版本之后(不含),您不需要使用该装饰器来装饰插件类,
|
||||
AstrBot 会自动识别继承自 Star 的类并将其作为插件类加载。
|
||||
|
||||
Args:
|
||||
name: 插件名称。
|
||||
author: 作者。
|
||||
@@ -21,18 +29,32 @@ def register_star(name: str, author: str, desc: str, version: str, repo: str = N
|
||||
帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。`
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
star_metadata = StarMetadata(
|
||||
name=name,
|
||||
author=author,
|
||||
desc=desc,
|
||||
version=version,
|
||||
repo=repo,
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__,
|
||||
global _warned_register_star
|
||||
if not _warned_register_star:
|
||||
_warned_register_star = True
|
||||
warnings.warn(
|
||||
"The 'register_star' decorator is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
star_registry.append(star_metadata)
|
||||
star_map[cls.__module__] = star_metadata
|
||||
|
||||
def decorator(cls):
|
||||
if not star_map.get(cls.__module__):
|
||||
metadata = StarMetadata(
|
||||
name=name,
|
||||
author=author,
|
||||
desc=desc,
|
||||
version=version,
|
||||
repo=repo,
|
||||
)
|
||||
star_map[cls.__module__] = metadata
|
||||
else:
|
||||
star_map[cls.__module__].name = name
|
||||
star_map[cls.__module__].author = author
|
||||
star_map[cls.__module__].desc = desc
|
||||
star_map[cls.__module__].version = version
|
||||
star_map[cls.__module__].repo = repo
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
293
astrbot/core/star/session_llm_manager.py
Normal file
293
astrbot/core/star/session_llm_manager.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionServiceManager:
|
||||
"""管理会话级别的服务启停状态,包括LLM和TTS"""
|
||||
|
||||
# =============================================================================
|
||||
# LLM 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_services.get("llm_enabled")
|
||||
if llm_enabled is not None:
|
||||
return llm_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置LLM状态
|
||||
session_config[session_id]["llm_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# TTS 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查TTS是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
tts_enabled = session_services.get("tts_enabled")
|
||||
if tts_enabled is not None:
|
||||
return tts_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置TTS在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置TTS状态
|
||||
session_config[session_id]["tts_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理TTS请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话整体启停相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_session_enabled(session_id: str) -> bool:
|
||||
"""检查会话是否整体启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
session_enabled = session_services.get("session_enabled")
|
||||
if session_enabled is not None:
|
||||
return session_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_session_status(session_id: str, enabled: bool) -> None:
|
||||
"""设置会话的整体启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置会话整体状态
|
||||
session_config[session_id]["session_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理会话请求(会话整体启停检查)
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_session_enabled(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话命名相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
||||
"""设置会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置自定义名称
|
||||
if custom_name and custom_name.strip():
|
||||
session_config[session_id]["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config[session_id].pop("custom_name", None)
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_display_name(session_id: str) -> str:
|
||||
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 显示名称
|
||||
"""
|
||||
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||
if custom_name:
|
||||
return custom_name
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
# =============================================================================
|
||||
# 通用配置方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_service_config(session_id: str) -> Dict[str, bool]:
|
||||
"""获取指定会话的服务配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
return session_config.get(
|
||||
session_id,
|
||||
{
|
||||
"session_enabled": True, # 默认启用
|
||||
"llm_enabled": True, # 默认启用
|
||||
"tts_enabled": True, # 默认启用
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
|
||||
"""获取所有会话的服务配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, bool]]: 所有会话的服务配置
|
||||
"""
|
||||
return sp.get("session_service_config", {}) or {}
|
||||
142
astrbot/core/star/session_plugin_manager.py
Normal file
142
astrbot/core/star/session_plugin_manager.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
会话插件管理器 - 负责管理每个会话的插件启停状态
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict, List
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionPluginManager:
|
||||
"""管理会话级别的插件启停状态"""
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||
"""检查插件是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
# 如果插件在禁用列表中,返回False
|
||||
if plugin_name in disabled_plugins:
|
||||
return False
|
||||
|
||||
# 如果插件在启用列表中,返回True
|
||||
if plugin_name in enabled_plugins:
|
||||
return True
|
||||
|
||||
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_plugin_status_for_session(
|
||||
session_id: str, plugin_name: str, enabled: bool
|
||||
) -> None:
|
||||
"""设置插件在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": [],
|
||||
}
|
||||
|
||||
session_config = session_plugin_config[session_id]
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
if enabled:
|
||||
# 启用插件
|
||||
if plugin_name in disabled_plugins:
|
||||
disabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in enabled_plugins:
|
||||
enabled_plugins.append(plugin_name)
|
||||
else:
|
||||
# 禁用插件
|
||||
if plugin_name in enabled_plugins:
|
||||
enabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in disabled_plugins:
|
||||
disabled_plugins.append(plugin_name)
|
||||
|
||||
# 保存配置
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put("session_plugin_config", session_plugin_config)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]:
|
||||
"""获取指定会话的插件配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
"""
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
return session_plugin_config.get(
|
||||
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
handlers: 原始处理器列表
|
||||
|
||||
Returns:
|
||||
List: 过滤后的处理器列表
|
||||
"""
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
session_id = event.unified_msg_origin
|
||||
filtered_handlers = []
|
||||
|
||||
for handler in handlers:
|
||||
# 获取处理器对应的插件
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not plugin:
|
||||
# 如果找不到插件元数据,允许执行(可能是系统插件)
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
# 跳过保留插件(系统插件)
|
||||
if plugin.reserved:
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin.name
|
||||
):
|
||||
filtered_handlers.append(handler)
|
||||
else:
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}"
|
||||
)
|
||||
|
||||
return filtered_handlers
|
||||
@@ -1,14 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
star_map: Dict[str, StarMetadata] = {}
|
||||
star_registry: list[StarMetadata] = []
|
||||
star_map: dict[str, StarMetadata] = {}
|
||||
"""key 是模块路径,__module__"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Star
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
@@ -18,22 +22,27 @@ class StarMetadata:
|
||||
当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。
|
||||
"""
|
||||
|
||||
name: str
|
||||
author: str # 插件作者
|
||||
desc: str # 插件简介
|
||||
version: str # 插件版本
|
||||
repo: str = None # 插件仓库地址
|
||||
name: str | None = None
|
||||
"""插件名"""
|
||||
author: str | None = None
|
||||
"""插件作者"""
|
||||
desc: str | None = None
|
||||
"""插件简介"""
|
||||
version: str | None = None
|
||||
"""插件版本"""
|
||||
repo: str | None = None
|
||||
"""插件仓库地址"""
|
||||
|
||||
star_cls_type: type = None
|
||||
star_cls_type: type[Star] | None = None
|
||||
"""插件的类对象的类型"""
|
||||
module_path: str = None
|
||||
module_path: str | None = None
|
||||
"""插件的模块路径"""
|
||||
|
||||
star_cls: object = None
|
||||
star_cls: Star | None = None
|
||||
"""插件的类对象"""
|
||||
module: ModuleType = None
|
||||
module: ModuleType | None = None
|
||||
"""插件的模块对象"""
|
||||
root_dir_name: str = None
|
||||
root_dir_name: str | None = None
|
||||
"""插件的目录名称"""
|
||||
reserved: bool = False
|
||||
"""是否是 AstrBot 的保留插件"""
|
||||
@@ -41,17 +50,20 @@ class StarMetadata:
|
||||
activated: bool = True
|
||||
"""是否被激活"""
|
||||
|
||||
config: AstrBotConfig = None
|
||||
config: AstrBotConfig | None = None
|
||||
"""插件配置"""
|
||||
|
||||
star_handler_full_names: List[str] = field(default_factory=list)
|
||||
star_handler_full_names: list[str] = field(default_factory=list)
|
||||
"""注册的 Handler 的全名列表"""
|
||||
|
||||
supported_platforms: Dict[str, bool] = field(default_factory=dict)
|
||||
supported_platforms: dict[str, bool] = field(default_factory=dict)
|
||||
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||
|
||||
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||
"""更新插件支持的平台列表
|
||||
|
||||
@@ -7,6 +7,7 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -49,7 +50,8 @@ class StarHandlerRegistry(Generic[T]):
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
return [
|
||||
handler for handler in self._handlers
|
||||
handler
|
||||
for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
@@ -67,6 +69,7 @@ class StarHandlerRegistry(Generic[T]):
|
||||
def __len__(self):
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -37,12 +36,6 @@ except ImportError:
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig):
|
||||
@@ -64,6 +57,8 @@ class PluginManager:
|
||||
"""保留插件的路径。在 packages 目录下"""
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
"""插件配置 Schema 文件名"""
|
||||
self._pm_lock = asyncio.Lock()
|
||||
"""StarManager操作互斥锁"""
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
@@ -119,7 +114,8 @@ class PluginManager:
|
||||
reloaded_plugins.add(plugin_name)
|
||||
break
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
@staticmethod
|
||||
def _get_classes(arg: ModuleType):
|
||||
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
||||
classes = []
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
@@ -129,7 +125,8 @@ class PluginManager:
|
||||
break
|
||||
return classes
|
||||
|
||||
def _get_modules(self, path):
|
||||
@staticmethod
|
||||
def _get_modules(path):
|
||||
modules = []
|
||||
|
||||
dirs = os.listdir(path)
|
||||
@@ -155,7 +152,7 @@ class PluginManager:
|
||||
)
|
||||
return modules
|
||||
|
||||
def _get_plugin_modules(self) -> List[dict]:
|
||||
def _get_plugin_modules(self) -> list[dict]:
|
||||
plugins = []
|
||||
if os.path.exists(self.plugin_store_path):
|
||||
plugins.extend(self._get_modules(self.plugin_store_path))
|
||||
@@ -166,7 +163,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str | None = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -189,10 +186,11 @@ class PluginManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj=None) -> StarMetadata:
|
||||
"""v3.4.0 以前的方式载入插件元数据
|
||||
@staticmethod
|
||||
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
|
||||
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
|
||||
先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
|
||||
"""
|
||||
metadata = None
|
||||
|
||||
@@ -204,11 +202,14 @@ class PluginManager:
|
||||
os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8"
|
||||
) as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
elif plugin_obj:
|
||||
elif plugin_obj and hasattr(plugin_obj, "info"):
|
||||
# 使用 info() 函数
|
||||
metadata = plugin_obj.info()
|
||||
|
||||
if isinstance(metadata, dict):
|
||||
if "desc" not in metadata and "description" in metadata:
|
||||
metadata["desc"] = metadata["description"]
|
||||
|
||||
if (
|
||||
"name" not in metadata
|
||||
or "desc" not in metadata
|
||||
@@ -228,8 +229,9 @@ class PluginManager:
|
||||
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_related_modules(
|
||||
self, plugin_root_dir: str, is_reserved: bool = False
|
||||
plugin_root_dir: str, is_reserved: bool = False
|
||||
) -> list[str]:
|
||||
"""获取与指定插件相关的所有已加载模块名
|
||||
|
||||
@@ -251,8 +253,8 @@ class PluginManager:
|
||||
|
||||
def _purge_modules(
|
||||
self,
|
||||
module_patterns: list[str] = None,
|
||||
root_dir_name: str = None,
|
||||
module_patterns: list[str] | None = None,
|
||||
root_dir_name: str | None = None,
|
||||
is_reserved: bool = False,
|
||||
):
|
||||
"""从 sys.modules 中移除指定的模块
|
||||
@@ -293,50 +295,51 @@ class PluginManager:
|
||||
- success (bool): 重载是否成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
if smd.name == specified_plugin_name:
|
||||
specified_module_path = smd.module_path
|
||||
break
|
||||
async with self._pm_lock:
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
if smd.name == specified_plugin_name:
|
||||
specified_module_path = smd.module_path
|
||||
break
|
||||
|
||||
# 终止插件
|
||||
if not specified_module_path:
|
||||
# 重载所有插件
|
||||
for smd in star_registry:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
# 终止插件
|
||||
if not specified_module_path:
|
||||
# 重载所有插件
|
||||
for smd in star_registry:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
if smd.name and smd.module_path:
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
if smd:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
if smd.name:
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
if smd:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
# 更新所有插件的平台兼容性
|
||||
await self.update_all_platform_compatibility()
|
||||
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
# 更新所有插件的平台兼容性
|
||||
await self.update_all_platform_compatibility()
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
async def update_all_platform_compatibility(self):
|
||||
"""更新所有插件的平台兼容性设置"""
|
||||
@@ -435,7 +438,7 @@ class PluginManager:
|
||||
)
|
||||
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
# 通过 __init__subclass__ 注册插件
|
||||
metadata = star_map[path]
|
||||
|
||||
try:
|
||||
@@ -449,13 +452,15 @@ class PluginManager:
|
||||
metadata.desc = metadata_yaml.desc
|
||||
metadata.version = metadata_yaml.version
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"插件 {root_dir_name} 元数据载入失败: {str(e)}。使用默认元数据。"
|
||||
)
|
||||
logger.info(metadata)
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
# metadata.config = plugin_config
|
||||
if plugin_config and metadata.star_cls_type:
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -464,7 +469,7 @@ class PluginManager:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
else:
|
||||
elif metadata.star_cls_type:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
@@ -481,6 +486,10 @@ class PluginManager:
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = (
|
||||
star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -489,7 +498,8 @@ class PluginManager:
|
||||
)
|
||||
for handler in related_handlers:
|
||||
handler.handler = functools.partial(
|
||||
handler.handler, metadata.star_cls
|
||||
handler.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
@@ -499,7 +509,8 @@ class PluginManager:
|
||||
):
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(
|
||||
func_tool.handler, metadata.star_cls
|
||||
func_tool.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
@@ -526,13 +537,12 @@ class PluginManager:
|
||||
obj = getattr(module, classes[0])(
|
||||
context=self.context
|
||||
) # 实例化插件类
|
||||
else:
|
||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||
|
||||
metadata = None
|
||||
metadata = self._load_plugin_metadata(
|
||||
plugin_path=plugin_dir_path, plugin_obj=obj
|
||||
)
|
||||
if not metadata:
|
||||
raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
@@ -547,6 +557,10 @@ class PluginManager:
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
full_names = []
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||
metadata.module_path
|
||||
@@ -586,7 +600,7 @@ class PluginManager:
|
||||
metadata.star_handler_full_names = full_names
|
||||
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
@@ -622,43 +636,45 @@ class PluginManager:
|
||||
- readme: README.md 文件的内容(如果存在)
|
||||
如果找不到插件元数据则返回 None。
|
||||
"""
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
await self.load(specified_dir_name=dir_name)
|
||||
async with self._pm_lock:
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
await self.load(specified_dir_name=dir_name)
|
||||
|
||||
# Get the plugin metadata to return repo info
|
||||
plugin = self.context.get_registered_star(dir_name)
|
||||
if not plugin:
|
||||
# Try to find by other name if directory name doesn't match plugin name
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
plugin = star
|
||||
break
|
||||
# Get the plugin metadata to return repo info
|
||||
plugin = self.context.get_registered_star(dir_name)
|
||||
if not plugin:
|
||||
# Try to find by other name if directory name doesn't match plugin name
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
plugin = star
|
||||
break
|
||||
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(plugin_path, "README.md")
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(plugin_path, "README.md")
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path) and nh3:
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
if os.path.exists(readme_path):
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}"
|
||||
)
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {
|
||||
"repo": plugin.repo,
|
||||
"readme": cleaned_content,
|
||||
"name": plugin.name,
|
||||
}
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {
|
||||
"repo": plugin.repo,
|
||||
"readme": readme_content,
|
||||
"name": plugin.name,
|
||||
}
|
||||
|
||||
return plugin_info
|
||||
return plugin_info
|
||||
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
"""卸载指定的插件。
|
||||
@@ -669,32 +685,33 @@ class PluginManager:
|
||||
Raises:
|
||||
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||
"""
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
async with self._pm_lock:
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# 终止插件
|
||||
try:
|
||||
await self._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
|
||||
)
|
||||
# 终止插件
|
||||
try:
|
||||
await self._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
|
||||
)
|
||||
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
|
||||
try:
|
||||
remove_dir(os.path.join(ppath, root_dir_name))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
||||
)
|
||||
try:
|
||||
remove_dir(os.path.join(ppath, root_dir_name))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
||||
)
|
||||
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
"""解绑并移除一个插件。
|
||||
@@ -725,6 +742,9 @@ class PluginManager:
|
||||
]:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
if plugin is None:
|
||||
return
|
||||
|
||||
self._purge_modules(
|
||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||
)
|
||||
@@ -747,35 +767,37 @@ class PluginManager:
|
||||
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
|
||||
并且同时将插件启用的 llm_tool 禁用。
|
||||
"""
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
async with self._pm_lock:
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
|
||||
# 调用插件的终止方法
|
||||
await self._terminate_plugin(plugin)
|
||||
# 调用插件的终止方法
|
||||
await self._terminate_plugin(plugin)
|
||||
|
||||
# 加入到 shared_preferences 中
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
# 加入到 shared_preferences 中
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = list(
|
||||
set(sp.get("inactivated_llm_tools", []))
|
||||
) # 后向兼容
|
||||
inactivated_llm_tools: list = list(
|
||||
set(sp.get("inactivated_llm_tools", []))
|
||||
) # 后向兼容
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = False
|
||||
plugin.activated = False
|
||||
|
||||
async def _terminate_plugin(self, star_metadata: StarMetadata):
|
||||
@staticmethod
|
||||
async def _terminate_plugin(star_metadata: StarMetadata):
|
||||
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
||||
logger.info(f"正在终止插件 {star_metadata.name} ...")
|
||||
|
||||
@@ -784,11 +806,14 @@ class PluginManager:
|
||||
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
||||
return
|
||||
|
||||
if hasattr(star_metadata.star_cls, "__del__"):
|
||||
if star_metadata.star_cls is None:
|
||||
return
|
||||
|
||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
)
|
||||
elif hasattr(star_metadata.star_cls, "terminate"):
|
||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||
await star_metadata.star_cls.terminate()
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
|
||||
@@ -182,7 +182,9 @@ class StarTools:
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
data_dir = Path(
|
||||
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
|
||||
)
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -56,9 +56,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
try:
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
if os.name == "nt":
|
||||
args = [
|
||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||
]
|
||||
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
|
||||
else:
|
||||
args = sys.argv[1:]
|
||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||
|
||||
@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
|
||||
raise exc_info[1]
|
||||
|
||||
|
||||
def remove_dir(file_path) -> bool:
|
||||
def remove_dir(file_path: str) -> bool:
|
||||
if not os.path.exists(file_path):
|
||||
return True
|
||||
shutil.rmtree(file_path, onerror=on_error)
|
||||
|
||||
29
astrbot/core/utils/session_lock.py
Normal file
29
astrbot/core/utils/session_lock.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
class SessionLockManager:
|
||||
def __init__(self):
|
||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._lock_count: dict[str, int] = defaultdict(int)
|
||||
self._access_lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire_lock(self, session_id: str):
|
||||
async with self._access_lock:
|
||||
lock = self._locks[session_id]
|
||||
self._lock_count[session_id] += 1
|
||||
|
||||
try:
|
||||
async with lock:
|
||||
yield
|
||||
finally:
|
||||
async with self._access_lock:
|
||||
self._lock_count[session_id] -= 1
|
||||
if self._lock_count[session_id] == 0:
|
||||
self._locks.pop(session_id, None)
|
||||
self._lock_count.pop(session_id, None)
|
||||
|
||||
|
||||
session_lock_manager = SessionLockManager()
|
||||
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
@@ -24,7 +27,7 @@ class SharedPreferences:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
|
||||
@@ -117,7 +117,7 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
|
||||
try:
|
||||
import pilk
|
||||
except ImportError as e:
|
||||
raise Exception("未安装 pysilk,请执行: pip install pysilk") from e
|
||||
raise Exception("未安装 pilk: pip install pilk") from e
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
@@ -9,6 +9,7 @@ from .chat import ChatRoute
|
||||
from .tools import ToolsRoute # 导入新的ToolsRoute
|
||||
from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
from .session_management import SessionManagementRoute
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -23,4 +24,5 @@ __all__ = [
|
||||
"ToolsRoute",
|
||||
"ConversationRoute",
|
||||
"FileRoute",
|
||||
"SessionManagementRoute",
|
||||
]
|
||||
|
||||
@@ -166,15 +166,12 @@ class ChatRoute(Route):
|
||||
type = result.get("type")
|
||||
cid = result.get("cid")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if streaming and type != "end":
|
||||
# If the result is still streaming, we continue to wait for more data
|
||||
continue
|
||||
|
||||
if result_text:
|
||||
if type == "end":
|
||||
break
|
||||
elif (streaming and type == "complete") or not streaming:
|
||||
# append bot message
|
||||
conversation = self.db.get_conversation_by_user_id(
|
||||
username, cid
|
||||
@@ -188,10 +185,6 @@ class ChatRoute(Route):
|
||||
self.db.update_conversation(
|
||||
username, cid, history=json.dumps(history)
|
||||
)
|
||||
if chain_type not in ["tool_call", "tool_call_result"]:
|
||||
# If the result is not a tool call or tool call result,
|
||||
# we can break the loop and end the stream
|
||||
break
|
||||
|
||||
except BaseException as _:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
@@ -187,15 +190,12 @@ class ConfigRoute(Route):
|
||||
"""辅助函数:测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
provider_name = provider.provider_config.get("id", "Unknown Provider")
|
||||
logger.debug(f"Got provider meta: {meta}")
|
||||
if not provider_name and meta:
|
||||
provider_name = meta.id
|
||||
elif not provider_name:
|
||||
provider_name = "Unknown Provider"
|
||||
provider_capability_type = meta.provider_type
|
||||
|
||||
status_info = {
|
||||
"id": getattr(meta, "id", "Unknown ID"),
|
||||
"model": getattr(meta, "model", "Unknown Model"),
|
||||
"type": getattr(meta, "type", "Unknown Type"),
|
||||
"type": provider_capability_type.value,
|
||||
"name": provider_name,
|
||||
"status": "unavailable", # 默认为不可用
|
||||
"error": None,
|
||||
@@ -203,92 +203,194 @@ class ConfigRoute(Route):
|
||||
logger.debug(
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
|
||||
)
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if hasattr(response, "completion_text") and response.completion_text:
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
|
||||
if provider_capability_type == ProviderType.CHAT_COMPLETION:
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
else:
|
||||
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
|
||||
logger.debug(
|
||||
f"Received response from {status_info['name']}: {response}"
|
||||
)
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if (
|
||||
hasattr(response, "completion_text")
|
||||
and response.completion_text
|
||||
):
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
)
|
||||
else:
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
|
||||
)
|
||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
||||
try:
|
||||
# For embedding, we can call the get_embedding method with a short prompt.
|
||||
embedding_result = await provider.get_embedding("health_check")
|
||||
if isinstance(embedding_result, list) and (
|
||||
not embedding_result or isinstance(embedding_result[0], float)
|
||||
):
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing embedding provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: {str(e)}"
|
||||
|
||||
elif provider_capability_type == ProviderType.TEXT_TO_SPEECH:
|
||||
try:
|
||||
# For TTS, we can call the get_audio method with a short prompt.
|
||||
audio_result = await provider.get_audio("你好")
|
||||
if isinstance(audio_result, str) and audio_result:
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"TTS test failed: unexpected result type {type(audio_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing TTS provider {provider_name}: {e}", exc_info=True
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: {str(e)}"
|
||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Sending health check audio to provider: {status_info['name']}"
|
||||
)
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(), "samples", "stt_health_check.wav"
|
||||
)
|
||||
if not os.path.exists(sample_audio_path):
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
"STT test failed: sample audio file not found."
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}"
|
||||
)
|
||||
else:
|
||||
text_result = await provider.get_text(sample_audio_path)
|
||||
if isinstance(text_result, str) and text_result:
|
||||
status_info["status"] = "available"
|
||||
snippet = (
|
||||
text_result[:70] + "..."
|
||||
if len(text_result) > 70
|
||||
else text_result
|
||||
)
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'"
|
||||
)
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
f"STT test failed: unexpected result type {type(text_result)}"
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error testing STT provider {provider_name}: {e}", exc_info=True
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: {str(e)}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
|
||||
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}"
|
||||
)
|
||||
status_info["status"] = "available"
|
||||
status_info["error"] = (
|
||||
"This provider type is not tested and is assumed to be available."
|
||||
)
|
||||
|
||||
return status_info
|
||||
|
||||
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
|
||||
def _error_response(
|
||||
self, message: str, status_code: int = 500, log_fn=logger.error
|
||||
):
|
||||
log_fn(message)
|
||||
# 记录更详细的traceback信息,但只在是严重错误时
|
||||
if status_code == 500:
|
||||
log_fn(traceback.format_exc())
|
||||
return Response().error(message, status_code=status_code).__dict__
|
||||
return Response().error(message).__dict__
|
||||
|
||||
async def check_one_provider_status(self):
|
||||
"""API: check a single LLM Provider's status by id"""
|
||||
provider_id = request.args.get("id")
|
||||
if not provider_id:
|
||||
return self._error_response("Missing provider_id parameter", 400, logger.warning)
|
||||
return self._error_response(
|
||||
"Missing provider_id parameter", 400, logger.warning
|
||||
)
|
||||
|
||||
logger.info(f"API call: /config/provider/check_one id={provider_id}")
|
||||
try:
|
||||
all_providers = self.core_lifecycle.star_context.get_all_providers()
|
||||
# replace manual loop with next(filter(...))
|
||||
target = next(
|
||||
(p for p in all_providers if p.provider_config.get("id") == provider_id),
|
||||
None
|
||||
)
|
||||
prov_mgr = self.core_lifecycle.provider_manager
|
||||
target = prov_mgr.inst_map.get(provider_id)
|
||||
|
||||
if not target:
|
||||
return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning)
|
||||
logger.warning(
|
||||
f"Provider with id '{provider_id}' not found in provider_manager."
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.error(f"Provider with id '{provider_id}' not found")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
result = await self._test_single_provider(target)
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
except Exception as e:
|
||||
return self._error_response(
|
||||
f"Critical error checking provider {provider_id}: {e}",
|
||||
500
|
||||
f"Critical error checking provider {provider_id}: {e}", 500
|
||||
)
|
||||
|
||||
async def get_configs(self):
|
||||
|
||||
@@ -166,7 +166,7 @@ class ConversationRoute(Route):
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
return Response().ok({"message": "对话删除成功"}).__dict__
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
from quart import make_response
|
||||
from astrbot.core import logger, LogBroker
|
||||
from .route import Route, RouteContext
|
||||
from .route import Route, RouteContext, Response
|
||||
|
||||
|
||||
class LogRoute(Route):
|
||||
@@ -10,6 +10,9 @@ class LogRoute(Route):
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||
self.app.add_url_rule(
|
||||
"/api/log-history", view_func=self.log_history, methods=["GET"]
|
||||
)
|
||||
|
||||
async def log(self):
|
||||
async def stream():
|
||||
@@ -23,7 +26,6 @@ class LogRoute(Route):
|
||||
**message, # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.07) # 控制发送频率,避免过快
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
@@ -43,3 +45,20 @@ class LogRoute(Route):
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
async def log_history(self):
|
||||
"""获取日志历史"""
|
||||
try:
|
||||
logs = list(self.log_broker.log_cache)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"logs": logs,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import traceback
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import ssl
|
||||
import certifi
|
||||
@@ -18,12 +20,6 @@ from astrbot.core.star.filter.regex import RegexFilter
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core import DEMO_MODE
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(
|
||||
@@ -81,15 +77,33 @@ class PluginRoute(Route):
|
||||
|
||||
async def get_online_plugins(self):
|
||||
custom = request.args.get("custom_registry")
|
||||
force_refresh = request.args.get("force_refresh", "false").lower() == "true"
|
||||
|
||||
cache_file = "data/plugins.json"
|
||||
|
||||
if custom:
|
||||
urls = [custom]
|
||||
else:
|
||||
urls = ["https://api.soulter.top/astrbot/plugins"]
|
||||
urls = [
|
||||
"https://api.soulter.top/astrbot/plugins",
|
||||
"https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json",
|
||||
]
|
||||
|
||||
# 新增:创建 SSL 上下文,使用 certifi 提供的根证书
|
||||
# 如果不是强制刷新,先检查缓存是否有效
|
||||
cached_data = None
|
||||
if not force_refresh:
|
||||
# 先检查MD5是否匹配,如果匹配则使用缓存
|
||||
if await self._is_cache_valid(cache_file):
|
||||
cached_data = self._load_plugin_cache(cache_file)
|
||||
if cached_data:
|
||||
logger.debug("缓存MD5匹配,使用缓存的插件市场数据")
|
||||
return Response().ok(cached_data).__dict__
|
||||
|
||||
# 尝试获取远程数据
|
||||
remote_data = None
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
async with aiohttp.ClientSession(
|
||||
@@ -97,14 +111,123 @@ class PluginRoute(Route):
|
||||
) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
remote_data = await response.json()
|
||||
|
||||
# 检查远程数据是否为空
|
||||
if not remote_data or (
|
||||
isinstance(remote_data, dict) and len(remote_data) == 0
|
||||
):
|
||||
logger.warning(f"远程插件市场数据为空: {url}")
|
||||
continue # 继续尝试其他URL或使用缓存
|
||||
|
||||
logger.info("成功获取远程插件市场数据")
|
||||
# 获取最新的MD5并保存到缓存
|
||||
current_md5 = await self._get_remote_md5()
|
||||
self._save_plugin_cache(
|
||||
cache_file, remote_data, current_md5
|
||||
)
|
||||
return Response().ok(remote_data).__dict__
|
||||
else:
|
||||
logger.error(f"请求 {url} 失败,状态码:{response.status}")
|
||||
except Exception as e:
|
||||
logger.error(f"请求 {url} 失败,错误:{e}")
|
||||
|
||||
return Response().error("获取插件列表失败").__dict__
|
||||
# 如果远程获取失败,尝试使用缓存数据
|
||||
if not cached_data:
|
||||
cached_data = self._load_plugin_cache(cache_file)
|
||||
|
||||
if cached_data:
|
||||
logger.warning("远程插件市场数据获取失败,使用缓存数据")
|
||||
return Response().ok(cached_data, "使用缓存数据,可能不是最新版本").__dict__
|
||||
|
||||
return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__
|
||||
|
||||
async def _is_cache_valid(self, cache_file: str) -> bool:
|
||||
"""检查缓存是否有效(基于MD5)"""
|
||||
try:
|
||||
if not os.path.exists(cache_file):
|
||||
return False
|
||||
|
||||
# 加载缓存文件
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cache_data = json.load(f)
|
||||
|
||||
cached_md5 = cache_data.get("md5")
|
||||
if not cached_md5:
|
||||
logger.debug("缓存文件中没有MD5信息")
|
||||
return False
|
||||
|
||||
# 获取远程MD5
|
||||
remote_md5 = await self._get_remote_md5()
|
||||
if not remote_md5:
|
||||
logger.warning("无法获取远程MD5,将使用缓存")
|
||||
return True # 如果无法获取远程MD5,认为缓存有效
|
||||
|
||||
is_valid = cached_md5 == remote_md5
|
||||
logger.debug(
|
||||
f"插件数据MD5: 本地={cached_md5}, 远程={remote_md5}, 有效={is_valid}"
|
||||
)
|
||||
return is_valid
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检查缓存有效性失败: {e}")
|
||||
return False
|
||||
|
||||
async def _get_remote_md5(self) -> str:
|
||||
"""获取远程插件数据的MD5"""
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, connector=connector
|
||||
) as session:
|
||||
async with session.get(
|
||||
"https://api.soulter.top/astrbot/plugins-md5"
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("md5", "")
|
||||
else:
|
||||
logger.error(f"获取MD5失败,状态码:{response.status}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"获取远程MD5失败: {e}")
|
||||
return ""
|
||||
|
||||
def _load_plugin_cache(self, cache_file: str):
|
||||
"""加载本地缓存的插件市场数据"""
|
||||
try:
|
||||
if os.path.exists(cache_file):
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cache_data = json.load(f)
|
||||
# 检查缓存是否有效
|
||||
if "data" in cache_data and "timestamp" in cache_data:
|
||||
logger.debug(
|
||||
f"加载缓存文件: {cache_file}, 缓存时间: {cache_data['timestamp']}"
|
||||
)
|
||||
return cache_data["data"]
|
||||
except Exception as e:
|
||||
logger.warning(f"加载插件市场缓存失败: {e}")
|
||||
return None
|
||||
|
||||
def _save_plugin_cache(self, cache_file: str, data, md5: str = None):
|
||||
"""保存插件市场数据到本地缓存"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
||||
|
||||
cache_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": data,
|
||||
"md5": md5 or "",
|
||||
}
|
||||
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(cache_data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"插件市场数据已缓存到: {cache_file}, MD5: {md5}")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存插件市场缓存失败: {e}")
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
@@ -332,9 +455,6 @@ class PluginRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_plugin_readme(self):
|
||||
if not nh3:
|
||||
return Response().error("未安装 nh3 库").__dict__
|
||||
|
||||
plugin_name = request.args.get("name")
|
||||
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
|
||||
|
||||
@@ -370,11 +490,9 @@ class PluginRoute(Route):
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"content": cleaned_content}, "成功获取README内容")
|
||||
.ok({"content": readme_content}, "成功获取README内容")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -395,12 +513,14 @@ class PluginRoute(Route):
|
||||
platform_type = platform.get("type", "")
|
||||
platform_id = platform.get("id", "")
|
||||
|
||||
platforms.append({
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
})
|
||||
platforms.append(
|
||||
{
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
}
|
||||
)
|
||||
|
||||
adjusted_platform_enable = {}
|
||||
for platform_id, plugins in platform_enable.items():
|
||||
@@ -409,11 +529,13 @@ class PluginRoute(Route):
|
||||
# 获取所有插件,包括系统内部插件
|
||||
plugins = []
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
plugins.append({
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
})
|
||||
plugins.append(
|
||||
{
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
|
||||
@@ -421,11 +543,13 @@ class PluginRoute(Route):
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
})
|
||||
.ok(
|
||||
{
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
673
astrbot/dashboard/routes/session_management.py
Normal file
673
astrbot/dashboard/routes/session_management.py
Normal file
@@ -0,0 +1,673 @@
|
||||
import traceback
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class SessionManagementRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db_helper: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/session/list": ("GET", self.list_sessions),
|
||||
"/session/update_persona": ("POST", self.update_session_persona),
|
||||
"/session/update_provider": ("POST", self.update_session_provider),
|
||||
"/session/get_session_info": ("POST", self.get_session_info),
|
||||
"/session/plugins": ("GET", self.get_session_plugins),
|
||||
"/session/update_plugin": ("POST", self.update_session_plugin),
|
||||
"/session/update_llm": ("POST", self.update_session_llm),
|
||||
"/session/update_tts": ("POST", self.update_session_tts),
|
||||
"/session/update_name": ("POST", self.update_session_name),
|
||||
"/session/update_status": ("POST", self.update_session_status),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
# 获取会话对话映射
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
|
||||
# 获取会话提供商偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
|
||||
# 获取可用的 personas
|
||||
personas = self.core_lifecycle.star_context.provider_manager.personas
|
||||
|
||||
# 获取可用的 providers
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
sessions = []
|
||||
|
||||
# 构建会话信息
|
||||
for session_id, conversation_id in session_conversations.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"session_enabled": SessionServiceManager.is_session_enabled(
|
||||
session_id
|
||||
),
|
||||
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"tts_enabled": SessionServiceManager.is_tts_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"platform": session_id.split(":")[0]
|
||||
if ":" in session_id
|
||||
else "unknown",
|
||||
"message_type": session_id.split(":")[1]
|
||||
if session_id.count(":") >= 1
|
||||
else "unknown",
|
||||
"session_name": SessionServiceManager.get_session_display_name(
|
||||
session_id
|
||||
),
|
||||
"session_raw_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
}
|
||||
|
||||
# 获取对话信息
|
||||
conversation = self.db_helper.get_conversation_by_user_id(
|
||||
session_id, conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
# 查找 persona 名称
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = (
|
||||
default_stt_provider.meta().id
|
||||
)
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = (
|
||||
default_tts_provider.meta().id
|
||||
)
|
||||
|
||||
sessions.append(session_info)
|
||||
|
||||
# 获取可用的 personas 和 providers 列表
|
||||
available_personas = [
|
||||
{"name": p["name"], "prompt": p.get("prompt", "")} for p in personas
|
||||
]
|
||||
|
||||
available_chat_providers = []
|
||||
for provider in provider_manager.provider_insts:
|
||||
meta = provider.meta()
|
||||
available_chat_providers.append(
|
||||
{
|
||||
"id": meta.id,
|
||||
"name": meta.id,
|
||||
"model": meta.model,
|
||||
"type": meta.type,
|
||||
}
|
||||
)
|
||||
|
||||
available_stt_providers = []
|
||||
for provider in provider_manager.stt_provider_insts:
|
||||
meta = provider.meta()
|
||||
available_stt_providers.append(
|
||||
{
|
||||
"id": meta.id,
|
||||
"name": meta.id,
|
||||
"model": meta.model,
|
||||
"type": meta.type,
|
||||
}
|
||||
)
|
||||
|
||||
available_tts_providers = []
|
||||
for provider in provider_manager.tts_provider_insts:
|
||||
meta = provider.meta()
|
||||
available_tts_providers.append(
|
||||
{
|
||||
"id": meta.id,
|
||||
"name": meta.id,
|
||||
"model": meta.model,
|
||||
"type": meta.type,
|
||||
}
|
||||
)
|
||||
|
||||
result = {
|
||||
"sessions": sessions,
|
||||
"available_personas": available_personas,
|
||||
"available_chat_providers": available_chat_providers,
|
||||
"available_stt_providers": available_stt_providers,
|
||||
"available_tts_providers": available_tts_providers,
|
||||
}
|
||||
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话列表失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_persona(self):
|
||||
"""更新指定会话的 persona"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
persona_name = data.get("persona_name")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if persona_name is None:
|
||||
return Response().error("缺少必要参数: persona_name").__dict__
|
||||
|
||||
# 获取会话当前的对话 ID
|
||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
|
||||
if not conversation_id:
|
||||
# 如果没有对话,创建一个新的对话
|
||||
conversation_id = await conversation_manager.new_conversation(
|
||||
session_id
|
||||
)
|
||||
|
||||
# 更新 persona
|
||||
await conversation_manager.update_conversation_persona_id(
|
||||
session_id, persona_name
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_provider(self):
|
||||
"""更新指定会话的 provider"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
provider_id = data.get("provider_id")
|
||||
# "chat_completion", "speech_to_text", "text_to_speech"
|
||||
provider_type = data.get("provider_type")
|
||||
|
||||
if not session_id or not provider_id or not provider_type:
|
||||
return (
|
||||
Response()
|
||||
.error("缺少必要参数: session_id, provider_id, provider_type")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 转换 provider_type 字符串为枚举
|
||||
if provider_type == "chat_completion":
|
||||
provider_type_enum = ProviderType.CHAT_COMPLETION
|
||||
elif provider_type == "speech_to_text":
|
||||
provider_type_enum = ProviderType.SPEECH_TO_TEXT
|
||||
elif provider_type == "text_to_speech":
|
||||
provider_type_enum = ProviderType.TEXT_TO_SPEECH
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"不支持的 provider_type: {provider_type}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 设置 provider
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_enum,
|
||||
umo=session_id,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}"
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话提供商失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_info(self):
|
||||
"""获取指定会话的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
# 获取会话对话信息
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
conversation_id = session_conversations.get(session_id)
|
||||
|
||||
if not conversation_id:
|
||||
return Response().error(f"会话 {session_id} 未找到对话").__dict__
|
||||
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"tts_enabled": None, # 将在下面设置
|
||||
"platform": session_id.split(":")[0]
|
||||
if ":" in session_id
|
||||
else "unknown",
|
||||
"message_type": session_id.split(":")[1]
|
||||
if session_id.count(":") >= 1
|
||||
else "unknown",
|
||||
"session_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
}
|
||||
|
||||
# 获取TTS状态
|
||||
session_info["tts_enabled"] = (
|
||||
SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
)
|
||||
|
||||
# 获取对话信息
|
||||
conversation = self.db_helper.get_conversation_by_user_id(
|
||||
session_id, conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
personas = provider_manager.personas
|
||||
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = default_stt_provider.meta().id
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = default_tts_provider.meta().id
|
||||
|
||||
return Response().ok(session_info).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话信息失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_plugins(self):
|
||||
"""获取指定会话的插件配置信息"""
|
||||
try:
|
||||
session_id = request.args.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 获取所有已激活的插件
|
||||
all_plugins = []
|
||||
plugin_manager = self.core_lifecycle.plugin_manager
|
||||
|
||||
for plugin in plugin_manager.context.get_all_stars():
|
||||
# 只显示已激活的插件,不包括保留插件
|
||||
if plugin.activated and not plugin.reserved:
|
||||
plugin_name = plugin.name or ""
|
||||
plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin_name
|
||||
)
|
||||
|
||||
all_plugins.append(
|
||||
{
|
||||
"name": plugin_name,
|
||||
"author": plugin.author,
|
||||
"desc": plugin.desc,
|
||||
"enabled": plugin_enabled,
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"plugins": all_plugins,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话插件配置失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话插件配置失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_plugin(self):
|
||||
"""更新指定会话的插件启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
plugin_name = data.get("plugin_name")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if not plugin_name:
|
||||
return Response().error("缺少必要参数: plugin_name").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 验证插件是否存在且已激活
|
||||
plugin_manager = self.core_lifecycle.plugin_manager
|
||||
plugin = plugin_manager.context.get_registered_star(plugin_name)
|
||||
|
||||
if not plugin:
|
||||
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
||||
|
||||
if not plugin.activated:
|
||||
return Response().error(f"插件 {plugin_name} 未激活").__dict__
|
||||
|
||||
if plugin.reserved:
|
||||
return (
|
||||
Response()
|
||||
.error(f"插件 {plugin_name} 是系统保留插件,无法管理")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 使用 SessionPluginManager 更新插件状态
|
||||
SessionPluginManager.set_plugin_status_for_session(
|
||||
session_id, plugin_name, enabled
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"插件 {plugin_name} 已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"plugin_name": plugin_name,
|
||||
"enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话插件状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_llm(self):
|
||||
"""更新指定会话的LLM启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新LLM状态
|
||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_tts(self):
|
||||
"""更新指定会话的TTS启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新TTS状态
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"TTS已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"tts_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话TTS状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_name(self):
|
||||
"""更新指定会话的自定义名称"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
custom_name = data.get("custom_name", "")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新会话名称
|
||||
SessionServiceManager.set_session_custom_name(session_id, custom_name)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"会话名称已更新为: {custom_name if custom_name.strip() else '已清除自定义名称'}",
|
||||
"session_id": session_id,
|
||||
"custom_name": custom_name,
|
||||
"display_name": SessionServiceManager.get_session_display_name(
|
||||
session_id
|
||||
),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话名称失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话名称失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_status(self):
|
||||
"""更新指定会话的整体启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
session_enabled = data.get("session_enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if session_enabled is None:
|
||||
return Response().error("缺少必要参数: session_enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新会话整体状态
|
||||
SessionServiceManager.set_session_status(session_id, session_enabled)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"会话整体状态已更新为: {'启用' if session_enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"session_enabled": session_enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
|
||||
@@ -2,6 +2,7 @@ import traceback
|
||||
import psutil
|
||||
import time
|
||||
import threading
|
||||
import aiohttp
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
@@ -25,6 +26,7 @@ class StatRoute(Route):
|
||||
"/stat/version": ("GET", self.get_version),
|
||||
"/stat/start-time": ("GET", self.get_start_time),
|
||||
"/stat/restart-core": ("POST", self.restart_core),
|
||||
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.register_routes()
|
||||
@@ -45,11 +47,7 @@ class StatRoute(Route):
|
||||
"""将总秒数转换为时分秒组件"""
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return {
|
||||
"hours": hours,
|
||||
"minutes": minutes,
|
||||
"seconds": seconds
|
||||
}
|
||||
return {"hours": hours, "minutes": minutes, "seconds": seconds}
|
||||
|
||||
def is_default_cred(self):
|
||||
username = self.config["dashboard"]["username"]
|
||||
@@ -144,3 +142,40 @@ class StatRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def test_ghproxy_connection(self):
|
||||
"""
|
||||
测试 GitHub 代理连接是否可用。
|
||||
"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
proxy_url: str = data.get("proxy_url")
|
||||
|
||||
if not proxy_url:
|
||||
return Response().error("proxy_url is required").__dict__
|
||||
|
||||
proxy_url = proxy_url.rstrip("/")
|
||||
|
||||
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
|
||||
start_time = time.time()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
test_url, timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
end_time = time.time()
|
||||
_ = await response.text()
|
||||
ret = {
|
||||
"latency": round((end_time - start_time) * 1000, 2),
|
||||
}
|
||||
return Response().ok(data=ret).__dict__
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed. Status code: {response.status}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {str(e)}").__dict__
|
||||
|
||||
@@ -26,6 +26,7 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/update": ("POST", self.update_mcp_server),
|
||||
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
|
||||
"/tools/mcp/market": ("GET", self.get_mcp_markets),
|
||||
"/tools/mcp/test": ("POST", self.test_mcp_connection),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
@@ -132,12 +133,19 @@ class ToolsRoute(Route):
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 动态初始化新MCP客户端
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, server_config, timeout=30
|
||||
)
|
||||
except TimeoutError:
|
||||
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
return Response().error("保存配置失败").__dict__
|
||||
@@ -193,31 +201,55 @@ class ToolsRoute(Route):
|
||||
if self.save_mcp_config(config):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
# 如果要激活服务器或者配置已更改
|
||||
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
else:
|
||||
# 客户端不存在,初始化
|
||||
await self.tool_mgr.mcp_service_queue.put({
|
||||
"type": "init",
|
||||
"name": name,
|
||||
"cfg": config["mcpServers"][name],
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError as e:
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, config["mcpServers"][name], timeout=30
|
||||
)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
# 如果要停用服务器
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
self.tool_mgr.mcp_service_queue.put_nowait({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 超时。")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
@@ -239,17 +271,23 @@ class ToolsRoute(Route):
|
||||
if name not in config["mcpServers"]:
|
||||
return Response().error(f"服务器 {name} 不存在").__dict__
|
||||
|
||||
# 删除服务器配置
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 关闭并删除MCP客户端
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
self.tool_mgr.mcp_service_queue.put_nowait({
|
||||
"type": "terminate",
|
||||
"name": name,
|
||||
})
|
||||
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
|
||||
else:
|
||||
return Response().error("保存配置失败").__dict__
|
||||
@@ -281,3 +319,20 @@ class ToolsRoute(Route):
|
||||
except Exception as _:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error("获取市场数据失败").__dict__
|
||||
|
||||
async def test_mcp_connection(self):
|
||||
"""
|
||||
测试 MCP 服务器连接
|
||||
"""
|
||||
try:
|
||||
server_data = await request.json
|
||||
config = server_data.get("mcp_server_config", None)
|
||||
|
||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||
return (
|
||||
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import logging
|
||||
import jwt
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from astrbot.core.config.default import VERSION
|
||||
from quart import Quart, request, jsonify, g
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
from .routes.route import RouteContext, Response
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
|
||||
APP: Quart = None
|
||||
|
||||
@@ -53,6 +57,9 @@ class AstrBotDashboard:
|
||||
self.tools_root = ToolsRoute(self.context, core_lifecycle)
|
||||
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
|
||||
self.file_route = FileRoute(self.context)
|
||||
self.session_management_route = SessionManagementRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
6
changelogs/v3.5.20.md
Normal file
6
changelogs/v3.5.20.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复: 工具调用的结果错误地被当作消息发送
|
||||
2. 新增: 支持对引用消息中的图片进行理解(QQ, Telegram)
|
||||
3. 优化: QQ 主动消息发送逻辑,优化合并消息、文件、语音、图片等的处理
|
||||
4. 优化: 移除插件的 @register 插件注册装饰器(插件只需要继承 Star 类即可,AstrBot 会自动处理),简化插件代码开发
|
||||
7
changelogs/v3.5.21.md
Normal file
7
changelogs/v3.5.21.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复: WebChat 下图片、音频消息没有被正确渲染
|
||||
2. 修复: 部分情况下,插件信息无法正确显示
|
||||
3. 修复: WebChat 下开启分段回复后,消息错位
|
||||
4. 优化: 提高插件加载的性能和稳定性
|
||||
5. 修复: WebUI 对话数据库页中,无法真正删除对话
|
||||
3
changelogs/v3.5.22.md
Normal file
3
changelogs/v3.5.22.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复: 用户环境没有 Docker 时,可能导致死锁(表现为在初始化 AstrBot 的时候卡住)
|
||||
18
changelogs/v3.5.23.md
Normal file
18
changelogs/v3.5.23.md
Normal file
@@ -0,0 +1,18 @@
|
||||
1. 改进: WebUI提供商徽标显示
|
||||
2. 修复:在LLMRequestSubStage中添加对提供商请求处理的调试日志记录
|
||||
3. 修复: 为嵌入模型提供商添加状态检查
|
||||
4. 新增: 支持在WebUI上管理会话
|
||||
5. 新增: 为ProviderMetadata添加provider_type字段并优化提供商可用性测试
|
||||
6. 改进: WebUI聊天页面Markdown代码块
|
||||
7. 修复: 讯飞模型工具使用错误
|
||||
8. 修复: 修复mcp导致的持续占用100% CPU
|
||||
9. 重构: mcp服务器重载机制
|
||||
10. 新增: 为WebChat页面添加文件上传按钮
|
||||
11. 优化: 工具使用页面用户界面
|
||||
12. 新增: 添加测试GitHub加速地址的组件
|
||||
13. 新增: 使用会话锁保证分段回复时的消息发送顺序
|
||||
14. 新增: 实现日志历史记录检索并改进日志流处理
|
||||
15. 杂务: 修改openai的嵌入模型默认维度为1024
|
||||
16. 修复:更新axios版本范围
|
||||
17. chore: remove adapters of WeChat personal account(gewechat)
|
||||
18. 新增: 为AstrBotConfig中的嵌套对象添加展开状态管理
|
||||
10
changelogs/v3.5.24.md
Normal file
10
changelogs/v3.5.24.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# What's Changed
|
||||
|
||||
> 新版本预告: v4.0.0 即将发布。
|
||||
|
||||
1. 新增: 添加对 ModelScope、Compshare(优云智算)的模版支持。
|
||||
2. 优化: 增加插件数据缓存,优化插件市场数据获取时的稳定性。
|
||||
|
||||
其他更新:
|
||||
|
||||
1. 现已支持在 1Panel 平台通过应用商城快捷部署 AstrBot。详见:[在 1Panel 部署 AstrBot](https://docs.astrbot.app/deploy/astrbot/1panel.html)
|
||||
@@ -17,7 +17,7 @@
|
||||
"@tiptap/starter-kit": "2.1.7",
|
||||
"@tiptap/vue-3": "2.1.7",
|
||||
"apexcharts": "3.42.0",
|
||||
"axios": "^1.6.2",
|
||||
"axios": ">=1.6.2 <1.10.0 || >1.10.0 <2.0.0",
|
||||
"axios-mock-adapter": "^1.22.0",
|
||||
"chance": "1.1.11",
|
||||
"d3": "^7.9.0",
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
<script setup>
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref } from 'vue'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
defineProps({
|
||||
const props = defineProps({
|
||||
metadata: {
|
||||
type: Object,
|
||||
required: true
|
||||
@@ -16,11 +16,21 @@ defineProps({
|
||||
metadataKey: {
|
||||
type: String,
|
||||
required: true
|
||||
},
|
||||
isEditing: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
})
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const filteredIterable = computed(() => {
|
||||
if (!props.iterable) return {}
|
||||
const { hint, ...rest } = props.iterable
|
||||
return rest
|
||||
})
|
||||
|
||||
const dialog = ref(false)
|
||||
const currentEditingKey = ref('')
|
||||
const currentEditingLanguage = ref('json')
|
||||
@@ -54,7 +64,19 @@ function saveEditedContent() {
|
||||
<v-card-text class="px-0 py-1">
|
||||
<!-- Object Type Configuration -->
|
||||
<div v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template" class="object-config">
|
||||
<div v-for="(val, key, index) in iterable" :key="key" class="config-item">
|
||||
<!-- Provider-level hint -->
|
||||
<v-alert
|
||||
v-if="iterable.hint && !isEditing"
|
||||
type="info"
|
||||
variant="tonal"
|
||||
class="mb-4"
|
||||
border="start"
|
||||
density="compact"
|
||||
>
|
||||
{{ iterable.hint }}
|
||||
</v-alert>
|
||||
|
||||
<div v-for="(val, key, index) in filteredIterable" :key="key" class="config-item">
|
||||
<!-- Nested Object -->
|
||||
<div v-if="metadata[metadataKey].items[key]?.type === 'object'" class="nested-object">
|
||||
<div v-if="metadata[metadataKey].items[key] && !metadata[metadataKey].items[key]?.invisible" class="nested-container">
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
hide-details
|
||||
density="compact"
|
||||
:model-value="getItemEnabled()"
|
||||
:loading="loading"
|
||||
:disabled="loading"
|
||||
v-bind="props"
|
||||
@update:model-value="toggleEnabled"
|
||||
></v-switch>
|
||||
@@ -47,7 +49,6 @@
|
||||
contain
|
||||
width="120"
|
||||
height="120"
|
||||
class="rounded-circle"
|
||||
></v-img>
|
||||
</div>
|
||||
</v-card>
|
||||
@@ -78,6 +79,10 @@ export default {
|
||||
bglogo: {
|
||||
type: String,
|
||||
default: null
|
||||
},
|
||||
loading: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
emits: ['toggle-enabled', 'delete', 'edit'],
|
||||
|
||||
152
dashboard/src/components/shared/ProxySelector.vue
Normal file
152
dashboard/src/components/shared/ProxySelector.vue
Normal file
@@ -0,0 +1,152 @@
|
||||
<template>
|
||||
<h5>GitHub 加速</h5>
|
||||
<v-radio-group class="mt-2" v-model="radioValue" hide-details="true">
|
||||
<v-radio label="不使用 GitHub 加速" value="0"></v-radio>
|
||||
<v-radio value="1">
|
||||
<template v-slot:label>
|
||||
<span>使用 GitHub 加速</span>
|
||||
<v-btn v-if="radioValue === '1'" class="ml-2" @click="testAllProxies" size="x-small"
|
||||
variant="tonal" :loading="loadingTestingConnection">
|
||||
测试代理连通性
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
<div v-if="radioValue === '1'" style="margin-left: 16px;">
|
||||
<v-radio-group v-model="githubProxyRadioControl" class="mt-2" hide-details="true">
|
||||
<v-radio color="success" v-for="(proxy, idx) in githubProxies" :key="proxy" :value="idx">
|
||||
<template v-slot:label>
|
||||
<div class="d-flex align-center">
|
||||
<span class="mr-2">{{ proxy }}</span>
|
||||
<div v-if="proxyStatus[idx]">
|
||||
<v-chip
|
||||
:color="proxyStatus[idx].available ? 'success' : 'error'"
|
||||
size="x-small"
|
||||
class="mr-1">
|
||||
{{ proxyStatus[idx].available ? '可用' : '不可用' }}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-if="proxyStatus[idx].available"
|
||||
color="info"
|
||||
size="x-small">
|
||||
{{ proxyStatus[idx].latency }}ms
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</v-radio>
|
||||
<v-radio color="primary" value="-1" label="自定义">
|
||||
<template v-slot:label v-if="githubProxyRadioControl === '-1'">
|
||||
<v-text-field density="compact" v-model="selectedGitHubProxy" variant="outlined"
|
||||
style="width: 100vw;" placeholder="自定义" hide-details="true">
|
||||
</v-text-field>
|
||||
</template>
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/settings');
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
githubProxies: [
|
||||
"https://edgeone.gh-proxy.com",
|
||||
"https://hk.gh-proxy.com/",
|
||||
"https://gh-proxy.com/",
|
||||
"https://gh.llkk.cc",
|
||||
],
|
||||
githubProxyRadioControl: "0", // the index of the selected proxy
|
||||
selectedGitHubProxy: "",
|
||||
radioValue: "0", // 0: 不使用, 1: 使用
|
||||
loadingTestingConnection: false,
|
||||
testingProxies: {},
|
||||
proxyStatus: {},
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
async testSingleProxy(idx) {
|
||||
this.testingProxies[idx] = true;
|
||||
|
||||
const proxy = this.githubProxies[idx];
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/stat/test-ghproxy-connection', {
|
||||
proxy_url: proxy
|
||||
});
|
||||
console.log(response.data);
|
||||
if (response.status === 200) {
|
||||
this.proxyStatus[idx] = {
|
||||
available: true,
|
||||
latency: Math.round(response.data.data.latency)
|
||||
};
|
||||
} else {
|
||||
this.proxyStatus[idx] = {
|
||||
available: false,
|
||||
latency: 0
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
this.proxyStatus[idx] = {
|
||||
available: false,
|
||||
latency: 0
|
||||
};
|
||||
} finally {
|
||||
this.testingProxies[idx] = false;
|
||||
}
|
||||
},
|
||||
|
||||
async testAllProxies() {
|
||||
this.loadingTestingConnection = true;
|
||||
|
||||
const promises = this.githubProxies.map((proxy, idx) =>
|
||||
this.testSingleProxy(idx)
|
||||
);
|
||||
|
||||
await Promise.all(promises);
|
||||
this.loadingTestingConnection = false;
|
||||
},
|
||||
},
|
||||
mounted() {
|
||||
this.selectedGitHubProxy = localStorage.getItem('selectedGitHubProxy') || "";
|
||||
this.radioValue = localStorage.getItem('githubProxyRadioValue') || "0";
|
||||
this.githubProxyRadioControl = localStorage.getItem('githubProxyRadioControl') || "0";
|
||||
},
|
||||
watch: {
|
||||
selectedGitHubProxy: function (newVal, oldVal) {
|
||||
if (!newVal) {
|
||||
newVal = ""
|
||||
}
|
||||
localStorage.setItem('selectedGitHubProxy', newVal);
|
||||
},
|
||||
radioValue: function (newVal) {
|
||||
localStorage.setItem('githubProxyRadioValue', newVal);
|
||||
if (newVal === "0") {
|
||||
this.selectedGitHubProxy = "";
|
||||
}
|
||||
},
|
||||
githubProxyRadioControl: function (newVal) {
|
||||
localStorage.setItem('githubProxyRadioControl', newVal);
|
||||
if (newVal !== "-1") {
|
||||
this.selectedGitHubProxy = this.githubProxies[newVal] || "";
|
||||
} else {
|
||||
this.selectedGitHubProxy = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style>
|
||||
.v-label {
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
</style>
|
||||
@@ -38,6 +38,7 @@ export class I18nLoader {
|
||||
{ name: 'features/chat', path: 'features/chat.json' },
|
||||
{ name: 'features/extension', path: 'features/extension.json' },
|
||||
{ name: 'features/conversation', path: 'features/conversation.json' },
|
||||
{ name: 'features/session-management', path: 'features/session-management.json' },
|
||||
{ name: 'features/tooluse', path: 'features/tool-use.json' },
|
||||
{ name: 'features/provider', path: 'features/provider.json' },
|
||||
{ name: 'features/platform', path: 'features/platform.json' },
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"extensionMarketplace": "Extension Market",
|
||||
"chat": "Chat",
|
||||
"conversation": "Conversations",
|
||||
"sessionManagement": "Session Management",
|
||||
"console": "Console",
|
||||
"alkaid": "Alkaid Lab",
|
||||
"about": "About",
|
||||
|
||||
@@ -32,7 +32,8 @@
|
||||
"cancel": "Cancel",
|
||||
"actions": "Actions",
|
||||
"back": "Back",
|
||||
"selectFile": "Select File"
|
||||
"selectFile": "Select File",
|
||||
"refresh": "Refresh"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "Enabled",
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"addPlatform": "Add Platform Adapter",
|
||||
"connectTitle": "Connect {name}",
|
||||
"viewTutorial": "View Tutorial",
|
||||
"noTemplates": "No platform templates available",
|
||||
"idConflict": {
|
||||
"title": "ID Conflict Warning",
|
||||
"message": "Detected duplicate ID \"{id}\". Please use a new ID.",
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
{
|
||||
"title": "Session Management",
|
||||
"subtitle": "Manage active sessions and configurations",
|
||||
"buttons": {
|
||||
"refresh": "Refresh",
|
||||
"edit": "Edit",
|
||||
"apply": "Apply Batch Settings",
|
||||
"editName": "Edit Session Name",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "Active Sessions",
|
||||
"sessionCount": "sessions",
|
||||
"noActiveSessions": "No active sessions",
|
||||
"noActiveSessionsDesc": "Sessions will appear here when users interact with the bot"
|
||||
},
|
||||
"search": {
|
||||
"placeholder": "Search sessions...",
|
||||
"platformFilter": "Platform Filter"
|
||||
},
|
||||
"table": {
|
||||
"headers": {
|
||||
"sessionStatus": "Session Status",
|
||||
"sessionInfo": "Session Info",
|
||||
"persona": "Persona",
|
||||
"chatProvider": "Chat Provider",
|
||||
"sttProvider": "STT Provider",
|
||||
"ttsProvider": "TTS Provider",
|
||||
"llmStatus": "LLM Status",
|
||||
"ttsStatus": "TTS Status",
|
||||
"pluginManagement": "Plugin Management"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled"
|
||||
},
|
||||
"persona": {
|
||||
"none": "No Persona"
|
||||
},
|
||||
"batchOperations": {
|
||||
"title": "Batch Operations",
|
||||
"setPersona": "Batch Set Persona",
|
||||
"setChatProvider": "Batch Set Chat Provider",
|
||||
"setSttProvider": "Batch Set STT Provider",
|
||||
"setTtsProvider": "Batch Set TTS Provider",
|
||||
"setLlmStatus": "Batch Set LLM Status",
|
||||
"setTtsStatus": "Batch Set TTS Status",
|
||||
"noSttProvider": "No STT Provider Available",
|
||||
"noTtsProvider": "No TTS Provider Available"
|
||||
},
|
||||
"pluginManagement": {
|
||||
"title": "Plugin Management",
|
||||
"noPlugins": "No available plugins",
|
||||
"noPluginsDesc": "Currently no active plugins",
|
||||
"loading": "Loading plugin list...",
|
||||
"author": "Author"
|
||||
},
|
||||
"nameEditor": {
|
||||
"title": "Edit Session Name",
|
||||
"customName": "Custom Name",
|
||||
"placeholder": "Enter custom session name (leave empty to use original name)",
|
||||
"originalName": "Original Name",
|
||||
"fullSessionId": "Full Session ID",
|
||||
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "Session list refreshed",
|
||||
"personaUpdateSuccess": "Persona updated successfully",
|
||||
"personaUpdateError": "Failed to update persona",
|
||||
"providerUpdateSuccess": "Provider updated successfully",
|
||||
"providerUpdateError": "Failed to update provider",
|
||||
"sessionStatusSuccess": "Session {status}",
|
||||
"llmStatusSuccess": "LLM {status}",
|
||||
"ttsStatusSuccess": "TTS {status}",
|
||||
"statusUpdateError": "Failed to update status",
|
||||
"loadSessionsError": "Failed to load session list",
|
||||
"batchUpdateSuccess": "Successfully batch updated {count} settings",
|
||||
"batchUpdatePartial": "Batch update completed, {success} successful, {error} failed",
|
||||
"loadPluginsError": "Failed to load plugin list",
|
||||
"pluginStatusSuccess": "Plugin {name} {status}",
|
||||
"pluginStatusError": "Failed to update plugin status",
|
||||
"nameUpdateSuccess": "Session name updated successfully",
|
||||
"nameUpdateError": "Failed to update session name"
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,9 @@
|
||||
"buttons": {
|
||||
"refresh": "Refresh",
|
||||
"add": "Add Server",
|
||||
"useTemplate": "Use Template"
|
||||
"useTemplateStdio": "Stdio Template",
|
||||
"useTemplateStreamableHttp": "Streamable HTTP Template",
|
||||
"useTemplateSse": "SSE Template"
|
||||
},
|
||||
"empty": "No MCP servers available, click Add Server to add one",
|
||||
"status": {
|
||||
@@ -28,8 +30,7 @@
|
||||
"functionTools": {
|
||||
"title": "Function Tools",
|
||||
"buttons": {
|
||||
"expand": "Expand",
|
||||
"collapse": "Collapse"
|
||||
"view": "View Tools"
|
||||
},
|
||||
"search": "Search function tools",
|
||||
"empty": "No function tools available",
|
||||
@@ -68,10 +69,6 @@
|
||||
"enable": "Enable Server",
|
||||
"config": "Server Configuration"
|
||||
},
|
||||
"configNotes": {
|
||||
"note1": "1. Some MCP servers may require filling in `API_KEY` or `TOKEN` information in env according to their requirements, please check if filled.",
|
||||
"note2": "2. When url parameter is specified in configuration: if `transport` parameter is also specified as `streamable_http`, Streamable HTTP is used, otherwise SSE connection is used."
|
||||
},
|
||||
"errors": {
|
||||
"configEmpty": "Configuration cannot be empty",
|
||||
"jsonFormat": "JSON format error: {error}",
|
||||
@@ -79,7 +76,8 @@
|
||||
},
|
||||
"buttons": {
|
||||
"cancel": "Cancel",
|
||||
"save": "Save"
|
||||
"save": "Save",
|
||||
"testConnection": "Test Connection"
|
||||
}
|
||||
},
|
||||
"serverDetail": {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"extensionMarketplace": "插件市场",
|
||||
"chat": "聊天",
|
||||
"conversation": "对话数据库",
|
||||
"sessionManagement": "会话管理",
|
||||
"console": "控制台",
|
||||
"alkaid": "Alkaid",
|
||||
"about": "关于",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user