Compare commits
220 Commits
v3.5.19
...
features/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38486bc5aa | ||
|
|
7b8800c4eb | ||
|
|
8f4625f53b | ||
|
|
1e5f243edb | ||
|
|
e5eab2af34 | ||
|
|
c10973e160 | ||
|
|
b1e4bff3ec | ||
|
|
c1202cda63 | ||
|
|
32d6cd7776 | ||
|
|
2f78d30e93 | ||
|
|
33407c9f0d | ||
|
|
d2d5ef1c5c | ||
|
|
98d8eaee02 | ||
|
|
10b9228060 | ||
|
|
5872f1e017 | ||
|
|
5073f21002 | ||
|
|
69aaf09ac8 | ||
|
|
6e61ee81d8 | ||
|
|
cfd05a8d17 | ||
|
|
29845fcc4c | ||
|
|
e204b180a8 | ||
|
|
563972fd29 | ||
|
|
cbe94b84fc | ||
|
|
aa6f73574d | ||
|
|
94f0419ef7 | ||
|
|
cefd2d7f49 | ||
|
|
81e1e545fb | ||
|
|
d516920e72 | ||
|
|
2171372246 | ||
|
|
d2df4d0cce | ||
|
|
6ab90fc123 | ||
|
|
1a84ebbb1e | ||
|
|
c9c0352369 | ||
|
|
9903b028a3 | ||
|
|
49def5d883 | ||
|
|
6975525b70 | ||
|
|
fbc4f8527b | ||
|
|
90cb5a1951 | ||
|
|
ac71d9f034 | ||
|
|
64bcbc9fc0 | ||
|
|
9e7d46f956 | ||
|
|
e911896cfb | ||
|
|
9c6d66093f | ||
|
|
b2e39b9701 | ||
|
|
e95ad4049b | ||
|
|
1df49d1d6f | ||
|
|
b71000e2f3 | ||
|
|
47e6ed455e | ||
|
|
92592fb9d9 | ||
|
|
02a9769b35 | ||
|
|
7640f11bfc | ||
|
|
be8a0991ed | ||
|
|
9fa44dbcfa | ||
|
|
61aac9c80c | ||
|
|
60af83cfee | ||
|
|
cf64e6c231 | ||
|
|
2cae941bae | ||
|
|
bc0784f41d | ||
|
|
b711140f26 | ||
|
|
c57d75e01a | ||
|
|
1d766001bb | ||
|
|
0759a11a85 | ||
|
|
cb749a38ab | ||
|
|
369eab18ab | ||
|
|
73edeae013 | ||
|
|
7d46314dc8 | ||
|
|
d5a53a89eb | ||
|
|
a85bc510dd | ||
|
|
2beea7d218 | ||
|
|
a93cd3dd5f | ||
|
|
6c1f540170 | ||
|
|
d026a9f009 | ||
|
|
a8e7dadd39 | ||
|
|
2f8d921adf | ||
|
|
0c6e526f94 | ||
|
|
b1e3018b6b | ||
|
|
87f05fce66 | ||
|
|
1b37530c96 | ||
|
|
db4d02c2e2 | ||
|
|
fd7811402b | ||
|
|
eb0325e627 | ||
|
|
842c3c8ea9 | ||
|
|
8b4b04ec09 | ||
|
|
9f32c9280f | ||
|
|
4fcd09cfa8 | ||
|
|
7a8d65d37d | ||
|
|
23129a9ba2 | ||
|
|
7f791e730b | ||
|
|
f7e296b349 | ||
|
|
712d4acaaa | ||
|
|
74a5c01f21 | ||
|
|
3ba8724d77 | ||
|
|
6313a7d8a9 | ||
|
|
432a3f520c | ||
|
|
191b3e42d4 | ||
|
|
a27f05fcb4 | ||
|
|
2f33e0b873 | ||
|
|
f0359467f1 | ||
|
|
d1db8cf2c8 | ||
|
|
b1985ed2ce | ||
|
|
140ddc70e6 | ||
|
|
d7fd616470 | ||
|
|
3ccbef141e | ||
|
|
e92fbb0443 | ||
|
|
bd270aed68 | ||
|
|
28d7864393 | ||
|
|
b5d8173ee3 | ||
|
|
17d62a9af7 | ||
|
|
d89fb863ed | ||
|
|
a21ad77820 | ||
|
|
f86c8e8cab | ||
|
|
cb12cbdd3d | ||
|
|
6661fa996c | ||
|
|
c19bca798b | ||
|
|
8f98b411db | ||
|
|
a8aa03847e | ||
|
|
1bfd747cc6 | ||
|
|
ae06d945a7 | ||
|
|
9f41d5f34d | ||
|
|
ef61c52908 | ||
|
|
d8842ef274 | ||
|
|
c88fdaf353 | ||
|
|
af295da871 | ||
|
|
083235a2fe | ||
|
|
2a3a5f7eb2 | ||
|
|
77c48f280f | ||
|
|
0ee1eb2f9f | ||
|
|
c2b20365bb | ||
|
|
cfdc7e4452 | ||
|
|
2363f61aa9 | ||
|
|
557ac6f9fa | ||
|
|
a49b871cf9 | ||
|
|
a0d6b3efba | ||
|
|
6cabf07bc0 | ||
|
|
a15444ee8c | ||
|
|
ceb5f5669e | ||
|
|
25b75e05e4 | ||
|
|
4d214bb5c1 | ||
|
|
7cbaed8c6c | ||
|
|
2915fdf665 | ||
|
|
a66c385b08 | ||
|
|
4dace7c5d8 | ||
|
|
8ebf087dbf | ||
|
|
2fa8bda5bb | ||
|
|
a5ae833945 | ||
|
|
d21d42b312 | ||
|
|
78575f0f0a | ||
|
|
8ccd292d16 | ||
|
|
2534f59398 | ||
|
|
5c60dbe2b1 | ||
|
|
c99ecde15f | ||
|
|
219f3403d9 | ||
|
|
00f417bad6 | ||
|
|
81649f053b | ||
|
|
e5bde50f2d | ||
|
|
0321e00b0d | ||
|
|
09528e3292 | ||
|
|
e7412a9cbf | ||
|
|
01efe5f869 | ||
|
|
28a178a55c | ||
|
|
88f130014c | ||
|
|
af258c590c | ||
|
|
b0eb5733be | ||
|
|
fe35bfba37 | ||
|
|
7cfbc4ab8f | ||
|
|
7a9d4f0abd | ||
|
|
6f6a5b565c | ||
|
|
e57deb873c | ||
|
|
0f692b1608 | ||
|
|
8c03e79f99 | ||
|
|
71290f0929 | ||
|
|
22364ef7de | ||
|
|
2cc1eb1abc | ||
|
|
90dbcbb4e2 | ||
|
|
66503d58be | ||
|
|
8e10f0ce2b | ||
|
|
f51f510f2e | ||
|
|
c44f085b47 | ||
|
|
a35f36eeaf | ||
|
|
14564c392a | ||
|
|
76e05ea749 | ||
|
|
ab599dceed | ||
|
|
4c37604445 | ||
|
|
bb74018d19 | ||
|
|
575289e5bc | ||
|
|
e89da2a7b4 | ||
|
|
d2f7e55bf5 | ||
|
|
9f31df7f3a | ||
|
|
e1bed60f1f | ||
|
|
edbb856023 | ||
|
|
98d3ab646f | ||
|
|
ab677ea100 | ||
|
|
28a87351f1 | ||
|
|
dcd7dcbbdf | ||
|
|
1538759ba7 | ||
|
|
ec5d71d0e1 | ||
|
|
d121d08d05 | ||
|
|
be08f4a558 | ||
|
|
4df8606ab6 | ||
|
|
71442d26ec | ||
|
|
4f5528869c | ||
|
|
f16feff17b | ||
|
|
d8aae538cd | ||
|
|
31670e75e5 | ||
|
|
ed6011a2be | ||
|
|
cdded38ade | ||
|
|
f536f24833 | ||
|
|
646b18d910 | ||
|
|
e24225c828 | ||
|
|
50a296de20 | ||
|
|
c79e38e044 | ||
|
|
dae745d925 | ||
|
|
791db65526 | ||
|
|
02e2e617f5 | ||
|
|
bfc8024119 | ||
|
|
f26cf6ed6f | ||
|
|
f2be55bd8e | ||
|
|
d241dd17ca | ||
|
|
cecafdfe6c | ||
|
|
6fecfd1a0e |
31
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md
vendored
31
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md
vendored
@@ -1,31 +0,0 @@
|
||||
---
|
||||
name: '🥳 发布插件'
|
||||
title: "[Plugin] 插件名"
|
||||
about: 提交插件到插件市场
|
||||
labels: [ "plugin-publish" ]
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
欢迎发布插件到插件市场!
|
||||
|
||||
## 插件基本信息
|
||||
|
||||
请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "插件名",
|
||||
"desc": "插件介绍",
|
||||
"author": "作者名",
|
||||
"repo": "插件仓库链接",
|
||||
"tags": [],
|
||||
"social_link": ""
|
||||
}
|
||||
```
|
||||
|
||||
## 检查
|
||||
|
||||
- [ ] 我的插件经过完整的测试
|
||||
- [ ] 我的插件不包含恶意代码
|
||||
- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
56
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
Normal file
56
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: 🥳 发布插件
|
||||
description: 提交插件到插件市场
|
||||
title: "[Plugin] 插件名"
|
||||
labels: ["plugin-publish"]
|
||||
assignees: []
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
欢迎发布插件到插件市场!
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## 插件基本信息
|
||||
|
||||
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
|
||||
|
||||
- type: textarea
|
||||
id: plugin-info
|
||||
attributes:
|
||||
label: 插件信息
|
||||
description: 请在下方代码块中填写您的插件信息,确保反引号包裹了JSON
|
||||
value: |
|
||||
```json
|
||||
{
|
||||
"name": "插件名",
|
||||
"desc": "插件介绍",
|
||||
"author": "作者名",
|
||||
"repo": "插件仓库链接",
|
||||
"tags": [],
|
||||
"social_link": ""
|
||||
}
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## 检查
|
||||
|
||||
- type: checkboxes
|
||||
id: checks
|
||||
attributes:
|
||||
label: 插件检查清单
|
||||
description: 请确认以下所有项目
|
||||
options:
|
||||
- label: 我的插件经过完整的测试
|
||||
required: true
|
||||
- label: 我的插件不包含恶意代码
|
||||
required: true
|
||||
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
63
.github/copilot-instructions.md
vendored
Normal file
63
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
# AstrBot Development Instructions
|
||||
|
||||
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
|
||||
|
||||
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
|
||||
|
||||
## Working Effectively
|
||||
|
||||
### Bootstrap and Install Dependencies
|
||||
- **Python 3.10+ required** - Check `.python-version` file
|
||||
- Install UV package manager: `pip install uv`
|
||||
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
|
||||
- Create required directories: `mkdir -p data/plugins data/config data/temp`
|
||||
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
- Navigate to dashboard: `cd dashboard`
|
||||
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
|
||||
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
|
||||
- Dashboard creates optimized production build in `dashboard/dist/`
|
||||
|
||||
### Testing
|
||||
- Do not generate test files for now.
|
||||
|
||||
### Code Quality and Linting
|
||||
- Install ruff linter: `uv add --dev ruff`
|
||||
- Check code style: `uv run ruff check .` -- takes <1 second
|
||||
- Check formatting: `uv run ruff format --check .` -- takes <1 second
|
||||
- Fix formatting: `uv run ruff format .`
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
### Common Issues and Workarounds
|
||||
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
|
||||
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
|
||||
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
|
||||
|
||||
## CI/CD Integration
|
||||
- GitHub Actions workflows in `.github/workflows/`
|
||||
- Docker builds supported via `Dockerfile`
|
||||
- Pre-commit hooks enforce ruff formatting and linting
|
||||
|
||||
## Docker Support
|
||||
- Primary deployment method: `docker run soulter/astrbot:latest`
|
||||
- Compose file available: `compose.yml`
|
||||
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
|
||||
- Volume mount required: `./data:/AstrBot/data`
|
||||
|
||||
## Multi-language Support
|
||||
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
|
||||
- UI supports internationalization
|
||||
- Default language is Chinese
|
||||
|
||||
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
|
||||
4
.github/workflows/auto_release.yml
vendored
4
.github/workflows/auto_release.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
|
||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
|
||||
20
.github/workflows/coverage_test.yml
vendored
20
.github/workflows/coverage_test.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
@@ -8,6 +8,7 @@ on:
|
||||
- 'README.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -16,7 +17,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -26,20 +27,19 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
pip install --editable .
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir data/temp
|
||||
mkdir -p data/plugins
|
||||
mkdir -p data/config
|
||||
mkdir -p data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
||||
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
14
.github/workflows/dashboard_ci.yml
vendored
14
.github/workflows/dashboard_ci.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: npm install, build
|
||||
run: |
|
||||
@@ -25,6 +25,8 @@ jobs:
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
mkdir -p dashboard/dist/assets
|
||||
echo $COMMIT_SHA > dashboard/dist/assets/version
|
||||
cd dashboard
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -33,3 +35,13 @@ jobs:
|
||||
path: |
|
||||
dashboard/dist
|
||||
!dist/**/*.md
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
tag: release-${{ github.sha }}
|
||||
owner: AstrBotDevs
|
||||
repo: astrbot-release-harbour
|
||||
body: "Automated release from commit ${{ github.sha }}"
|
||||
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
2
.github/workflows/docker-image.yml
vendored
2
.github/workflows/docker-image.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Pull The Codes
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
|
||||
|
||||
100
README.md
100
README.md
@@ -1,4 +1,4 @@
|
||||
<p align="center">
|
||||
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
|
||||
|
||||

|
||||
|
||||
@@ -25,59 +25,52 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
>
|
||||
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
<details><summary>1. AstrBot 现已自带知识库能力</summary>
|
||||
|
||||
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
> [!NOTE]
|
||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
> [!TIP]
|
||||
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
|
||||
## ✨ 使用方式
|
||||
|
||||
#### Docker 部署
|
||||
|
||||
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||
|
||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||
|
||||
#### 1Panel 部署
|
||||
|
||||
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||
|
||||
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
|
||||
|
||||
#### 在 雨云 上部署
|
||||
|
||||
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
@@ -101,31 +94,14 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者,直接通过 uvx 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
mkdir astrbot && cd astrbot
|
||||
uvx astrbot init
|
||||
# uvx astrbot run
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### 在 雨云 上部署
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| 微信个人号 | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
@@ -144,7 +120,7 @@ uvx astrbot init
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
@@ -152,6 +128,8 @@ uvx astrbot init
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
@@ -244,11 +222,5 @@ _✨ WebUI ✨_
|
||||

|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||
3. Please ensure compliance with local laws and regulations when using this project.
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
15
README_ja.md
15
README_ja.md
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
## ✨ 主な機能
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
@@ -35,7 +35,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
|
||||
> [!TIP]
|
||||
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
>
|
||||
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
|
||||
|
||||
## ✨ 使用方法
|
||||
@@ -136,11 +136,11 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> [!TIP]
|
||||
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
## 免責事項
|
||||
|
||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
||||
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
|
||||
<!-- ## ✨ ATRI [ベータテスト]
|
||||
|
||||
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
@@ -3,5 +3,18 @@ from astrbot import logger
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
from astrbot.core.star.register import register_agent as agent
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
|
||||
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"agent",
|
||||
"sp",
|
||||
"ToolSet",
|
||||
"FunctionTool",
|
||||
"BaseFunctionToolExecutor",
|
||||
]
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "3.5.8"
|
||||
__version__ = "3.5.23"
|
||||
|
||||
@@ -117,6 +117,9 @@ def build_plug_list(plugins_dir: Path) -> list:
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
if "desc" not in metadata and "description" in metadata:
|
||||
metadata["desc"] = metadata["description"]
|
||||
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import asyncio
|
||||
from .log import LogManager, LogBroker # noqa
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
@@ -21,7 +20,7 @@ html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences()
|
||||
sp = SharedPreferences(db_helper=db_helper)
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
|
||||
13
astrbot/core/agent/agent.py
Normal file
13
astrbot/core/agent/agent.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from .tool import FunctionTool
|
||||
from typing import Generic
|
||||
from .run_context import TContext
|
||||
from .hooks import BaseAgentRunHooks
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent(Generic[TContext]):
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
tools: list[str, FunctionTool] | None = None
|
||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||
34
astrbot/core/agent/handoff.py
Normal file
34
astrbot/core/agent/handoff.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Generic
|
||||
from .tool import FunctionTool
|
||||
from .agent import Agent
|
||||
from .run_context import TContext
|
||||
|
||||
|
||||
class HandoffTool(FunctionTool, Generic[TContext]):
|
||||
"""Handoff tool for delegating tasks to another agent."""
|
||||
|
||||
def __init__(
|
||||
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
|
||||
):
|
||||
self.agent = agent
|
||||
super().__init__(
|
||||
name=f"transfer_to_{agent.name}",
|
||||
parameters=parameters or self.default_parameters(),
|
||||
description=agent.instructions or self.default_description(agent.name),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def default_parameters(self) -> dict:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def default_description(self, agent_name: str | None) -> str:
|
||||
agent_name = agent_name or "another"
|
||||
return f"Delegate tasks to {self.name} agent to handle the request."
|
||||
27
astrbot/core/agent/hooks.py
Normal file
27
astrbot/core/agent/hooks.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import mcp
|
||||
from dataclasses import dataclass
|
||||
from .run_context import ContextWrapper, TContext
|
||||
from typing import Generic
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseAgentRunHooks(Generic[TContext]):
|
||||
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
|
||||
async def on_tool_start(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
): ...
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
tool_result: mcp.types.CallToolResult | None,
|
||||
): ...
|
||||
async def on_agent_done(
|
||||
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
|
||||
): ...
|
||||
208
astrbot/core/agent/mcp_client.py
Normal file
208
astrbot/core/agent/mcp_client.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
12
astrbot/core/agent/response.py
Normal file
12
astrbot/core/agent/response.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
import typing as T
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
17
astrbot/core/agent/run_context.py
Normal file
17
astrbot/core/agent/run_context.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
context: TContext
|
||||
event: AstrMessageEvent
|
||||
|
||||
NoContext = ContextWrapper[None]
|
||||
3
astrbot/core/agent/runners/__init__.py
Normal file
3
astrbot/core/agent/runners/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAgentRunner
|
||||
|
||||
__all__ = ["BaseAgentRunner"]
|
||||
@@ -1,32 +1,33 @@
|
||||
import abc
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ....message.message_event_result import MessageChain
|
||||
from enum import Enum, auto
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..response import AgentResponse
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Agent 状态枚举"""
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
"""Defines the state of the agent."""
|
||||
|
||||
IDLE = auto() # Initial state
|
||||
RUNNING = auto() # Currently processing
|
||||
DONE = auto() # Completed
|
||||
ERROR = auto() # Error state
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
class BaseAgentRunner:
|
||||
class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(self) -> None:
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the agent to its initial state.
|
||||
This method should be called before starting a new run.
|
||||
@@ -1,10 +1,12 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
|
||||
from ...context import PipelineContext
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentState
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..response import AgentResponseData
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -21,8 +23,8 @@ from mcp.types import (
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
@@ -31,28 +33,25 @@ else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# TODO:
|
||||
# 1. 处理平台不兼容的处理器
|
||||
|
||||
|
||||
class ToolLoopAgent(BaseAgentRunner):
|
||||
def __init__(
|
||||
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.req = None
|
||||
self.event = event
|
||||
self.pipeline_ctx = pipeline_ctx
|
||||
self._state = AgentState.IDLE
|
||||
self.final_llm_resp = None
|
||||
self.streaming = False
|
||||
|
||||
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
@override
|
||||
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
|
||||
self.req = req
|
||||
self.streaming = streaming
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.tool_executor = tool_executor
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
@@ -78,6 +77,12 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
@@ -124,12 +129,10 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
# 执行事件钩子
|
||||
if await self.pipeline_ctx.call_event_hook(
|
||||
self.event, EventType.OnLLMResponseEvent, llm_resp
|
||||
):
|
||||
return
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回 LLM 结果
|
||||
if llm_resp.result_chain:
|
||||
@@ -193,50 +196,33 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if not res:
|
||||
continue
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
|
||||
|
||||
executor = self.tool_executor.execute(
|
||||
tool=func_tool,
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
)
|
||||
async for resp in executor:
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
@@ -247,43 +233,85 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
else:
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
# 尝试调用工具函数
|
||||
wrapper = self.pipeline_ctx.call_handler(
|
||||
self.event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
# Tool 返回结果
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resp)
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
).base64_image(res.content[0].data)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
self.event.clear_result()
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context,
|
||||
func_tool_name,
|
||||
func_tool_args,
|
||||
resp,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.run_context.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
|
||||
self.run_context.event.clear_result()
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
256
astrbot/core/agent/tool.py
Normal file
256
astrbot/core/agent/tool.py
Normal file
@@ -0,0 +1,256 @@
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
from typing import Awaitable, Literal, Any, Optional
|
||||
from .mcp_client import MCPClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool:
|
||||
"""A class representing a function tool that can be used in function calling."""
|
||||
|
||||
name: str | None = None
|
||||
parameters: dict | None = None
|
||||
description: str | None = None
|
||||
handler: Awaitable | None = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str | None = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str | None = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient | None = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
def __dict__(self) -> dict[str, Any]:
|
||||
"""将 FunctionTool 转换为字典格式"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"parameters": self.parameters,
|
||||
"description": self.description,
|
||||
"active": self.active,
|
||||
"origin": self.origin,
|
||||
"mcp_server_name": self.mcp_server_name,
|
||||
}
|
||||
|
||||
|
||||
class ToolSet:
|
||||
"""A set of function tools that can be used in function calling.
|
||||
|
||||
This class provides methods to add, remove, and retrieve tools, as well as
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||
|
||||
def __init__(self, tools: list[FunctionTool] = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self.tools) == 0
|
||||
|
||||
def add_tool(self, tool: FunctionTool):
|
||||
"""Add a tool to the set."""
|
||||
# 检查是否已存在同名工具
|
||||
for i, existing_tool in enumerate(self.tools):
|
||||
if existing_tool.name == tool.name:
|
||||
self.tools[i] = tool
|
||||
return
|
||||
self.tools.append(tool)
|
||||
|
||||
def remove_tool(self, name: str):
|
||||
"""Remove a tool by its name."""
|
||||
self.tools = [tool for tool in self.tools if tool.name != name]
|
||||
|
||||
def get_tool(self, name: str) -> Optional[FunctionTool]:
|
||||
"""Get a tool by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FunctionTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
self.add_tool(_func)
|
||||
|
||||
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
|
||||
def remove_func(self, name: str):
|
||||
"""Remove a function tool by its name."""
|
||||
self.remove_tool(name)
|
||||
|
||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||
def get_func(self, name: str) -> list[FunctionTool]:
|
||||
"""Get all function tools."""
|
||||
return self.get_tool(name)
|
||||
|
||||
@property
|
||||
def func_list(self) -> list[FunctionTool]:
|
||||
"""Get the list of function tools."""
|
||||
return self.tools
|
||||
|
||||
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
|
||||
"""Convert tools to OpenAI API function calling schema format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
func_def = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
},
|
||||
}
|
||||
|
||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
return result
|
||||
|
||||
def anthropic_schema(self) -> list[dict]:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
tool_def = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": tool.parameters.get("properties", {}),
|
||||
"required": tool.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
def google_schema(self) -> dict:
|
||||
"""Convert tools to Google GenAI API format."""
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""Convert schema to Gemini API format."""
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties:
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": convert_schema(tool.parameters),
|
||||
}
|
||||
for tool in self.tools
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
|
||||
return self.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
|
||||
def get_func_desc_anthropic_style(self):
|
||||
return self.anthropic_schema()
|
||||
|
||||
@deprecated(reason="Use google_schema() instead", version="4.0.0")
|
||||
def get_func_desc_google_genai_style(self):
|
||||
return self.google_schema()
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""获取所有工具的名称列表"""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tools)
|
||||
|
||||
def __bool__(self):
|
||||
return len(self.tools) > 0
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
|
||||
def __str__(self):
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
11
astrbot/core/agent/tool_executor.py
Normal file
11
astrbot/core/agent/tool_executor.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import mcp
|
||||
from typing import Any, Generic, AsyncGenerator
|
||||
from .run_context import TContext, ContextWrapper
|
||||
from .tool import FunctionTool
|
||||
|
||||
|
||||
class BaseFunctionToolExecutor(Generic[TContext]):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
|
||||
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...
|
||||
11
astrbot/core/astr_agent_context.py
Normal file
11
astrbot/core/astr_agent_context.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
provider: Provider
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
276
astrbot/core/astrbot_config_mgr.py
Normal file
276
astrbot/core/astrbot_config_mgr.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.core import AstrBotConfig, logger
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
|
||||
from typing import TypeVar, TypedDict
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class ConfInfo(TypedDict):
|
||||
"""Configuration information for a specific session or platform."""
|
||||
|
||||
id: str # UUID of the configuration or "default"
|
||||
umop: list[str] # Unified Message Origin Pattern
|
||||
name: str
|
||||
path: str # File name to the configuration file
|
||||
|
||||
|
||||
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
id="default",
|
||||
umop=["::"],
|
||||
name="default",
|
||||
path=ASTRBOT_CONFIG_PATH,
|
||||
)
|
||||
|
||||
|
||||
class AstrBotConfigManager:
|
||||
"""A class to manage the system configuration of AstrBot, aka ACM"""
|
||||
|
||||
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
|
||||
self.sp = sp
|
||||
self.confs: dict[str, AstrBotConfig] = {}
|
||||
"""uuid / "default" -> AstrBotConfig"""
|
||||
self.confs["default"] = default_config
|
||||
self._load_all_configs()
|
||||
|
||||
def _load_all_configs(self):
|
||||
"""Load all configurations from the shared preferences."""
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
for uuid_, meta in abconf_data.items():
|
||||
filename = meta["path"]
|
||||
conf_path = os.path.join(get_astrbot_config_path(), filename)
|
||||
if os.path.exists(conf_path):
|
||||
conf = AstrBotConfig(config_path=conf_path)
|
||||
self.confs[uuid_] = conf
|
||||
else:
|
||||
logger.warning(
|
||||
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
|
||||
|
||||
Returns:
|
||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||
"""
|
||||
# uuid -> { "umop": list, "path": str, "name": str }
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
try:
|
||||
umo = str(MessageSession.from_str(umo)) # validate
|
||||
except Exception:
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
for uuid_, meta in abconf_data.items():
|
||||
for pattern in meta["umop"]:
|
||||
if self._is_umo_match(pattern, umo):
|
||||
return ConfInfo(**meta, id=uuid_)
|
||||
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
def _save_conf_mapping(
|
||||
self,
|
||||
abconf_path: str,
|
||||
abconf_id: str,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
abconf_name: str | None = None,
|
||||
) -> None:
|
||||
"""保存配置文件的映射关系"""
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
random_word = abconf_name or uuid.uuid4().hex[:8]
|
||||
abconf_data[abconf_id] = {
|
||||
"umop": umo_parts,
|
||||
"path": abconf_path,
|
||||
"name": random_word,
|
||||
}
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
|
||||
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
||||
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
||||
if not umo:
|
||||
return self.confs["default"]
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
uuid_ = self._load_conf_mapping(umo)["id"]
|
||||
|
||||
conf = self.confs.get(uuid_)
|
||||
if not conf:
|
||||
conf = self.confs["default"] # default MUST exists
|
||||
|
||||
return conf
|
||||
|
||||
@property
|
||||
def default_conf(self) -> AstrBotConfig:
|
||||
"""获取默认配置文件"""
|
||||
return self.confs["default"]
|
||||
|
||||
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件元数据"""
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
return self._load_conf_mapping(umo)
|
||||
|
||||
def get_conf_list(self) -> list[ConfInfo]:
|
||||
"""获取所有配置文件的元数据列表"""
|
||||
conf_list = []
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
abconf_mapping = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
for uuid_, meta in abconf_mapping.items():
|
||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||
return conf_list
|
||||
|
||||
def create_conf(
|
||||
self,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
config: dict = DEFAULT_CONFIG,
|
||||
name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
|
||||
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
||||
"""
|
||||
conf_uuid = str(uuid.uuid4())
|
||||
conf_file_name = f"abconf_{conf_uuid}.json"
|
||||
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
|
||||
conf = AstrBotConfig(config_path=conf_path, default_config=config)
|
||||
conf.save_config()
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
|
||||
self.confs[conf_uuid] = conf
|
||||
return conf_uuid
|
||||
|
||||
def delete_conf(self, conf_id: str) -> bool:
|
||||
"""删除指定配置文件
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
|
||||
Raises:
|
||||
ValueError: 如果试图删除默认配置文件
|
||||
"""
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能删除默认配置文件")
|
||||
|
||||
# 从映射中移除
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
|
||||
# 获取配置文件路径
|
||||
conf_path = os.path.join(
|
||||
get_astrbot_config_path(), abconf_data[conf_id]["path"]
|
||||
)
|
||||
|
||||
# 删除配置文件
|
||||
try:
|
||||
if os.path.exists(conf_path):
|
||||
os.remove(conf_path)
|
||||
logger.info(f"已删除配置文件: {conf_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
|
||||
return False
|
||||
|
||||
# 从内存中移除
|
||||
if conf_id in self.confs:
|
||||
del self.confs[conf_id]
|
||||
|
||||
# 从映射中移除
|
||||
del abconf_data[conf_id]
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
|
||||
) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
name: 新的配置文件名称 (可选)
|
||||
umo_parts: 新的 UMO 部分列表 (可选)
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能更新默认配置文件的信息")
|
||||
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
|
||||
# 更新名称
|
||||
if name is not None:
|
||||
abconf_data[conf_id]["name"] = name
|
||||
|
||||
# 更新 UMO 部分
|
||||
if umo_parts is not None:
|
||||
# 验证 UMO 部分格式
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
logger.info(f"成功更新配置文件 {conf_id} 的信息")
|
||||
return True
|
||||
|
||||
def g(
|
||||
self, umo: str | None = None, key: str | None = None, default: _VT = None
|
||||
) -> _VT:
|
||||
"""获取配置项。umo 为 None 时使用默认配置"""
|
||||
if umo is None:
|
||||
return self.confs["default"].get(key, default)
|
||||
conf = self.get_conf(umo)
|
||||
return conf.get(key, default)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,40 +5,44 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
||||
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.db.po import Conversation, ConversationV2
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.session_conversations: Dict[str, str] = {}
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
"""启动定时保存任务"""
|
||||
asyncio.create_task(self._periodic_save())
|
||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||
created_at = int(conv_v2.created_at.timestamp())
|
||||
updated_at = int(conv_v2.updated_at.timestamp())
|
||||
return Conversation(
|
||||
platform_id=conv_v2.platform_id,
|
||||
user_id=conv_v2.user_id,
|
||||
cid=conv_v2.conversation_id,
|
||||
history=json.dumps(conv_v2.content or []),
|
||||
title=conv_v2.title,
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
async def _periodic_save(self):
|
||||
"""定时保存会话对话映射关系到存储中"""
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
"""保存会话对话映射关系到存储中"""
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
async def new_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
platform_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
) -> str:
|
||||
"""新建对话,并将当前会话的对话转移到新对话
|
||||
|
||||
Args:
|
||||
@@ -46,11 +50,23 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return conversation_id
|
||||
if not platform_id:
|
||||
# 如果没有提供 platform_id,则从 unified_msg_origin 中解析
|
||||
parts = unified_msg_origin.split(":")
|
||||
if len(parts) >= 3:
|
||||
platform_id = parts[0]
|
||||
if not platform_id:
|
||||
platform_id = "unknown"
|
||||
conv = await self.db.create_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
platform_id=platform_id,
|
||||
content=content,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conv.conversation_id
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
|
||||
return conv.conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
"""切换会话的对话
|
||||
@@ -60,10 +76,10 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
|
||||
|
||||
async def delete_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str = None
|
||||
self, unified_msg_origin: str, conversation_id: str | None = None
|
||||
):
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||
|
||||
@@ -71,13 +87,18 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
f = False
|
||||
if not conversation_id:
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
f = True
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
del self.session_conversations[unified_msg_origin]
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
Args:
|
||||
@@ -85,14 +106,19 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
ret = self.session_conversations.get(unified_msg_origin, None)
|
||||
if not ret:
|
||||
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
|
||||
if ret:
|
||||
self.session_conversations[unified_msg_origin] = ret
|
||||
return ret
|
||||
|
||||
async def get_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str,
|
||||
create_if_not_exists: bool = False,
|
||||
) -> Conversation:
|
||||
) -> Conversation | None:
|
||||
"""获取会话的对话
|
||||
|
||||
Args:
|
||||
@@ -101,27 +127,74 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation (Conversation): 对话对象
|
||||
"""
|
||||
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
||||
if not conv and create_if_not_exists:
|
||||
# 如果对话不存在且需要创建,则新建一个对话
|
||||
conversation_id = await self.new_conversation(unified_msg_origin)
|
||||
return self.db.get_conversation_by_user_id(
|
||||
unified_msg_origin, conversation_id
|
||||
)
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
||||
conv_res = None
|
||||
if conv:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
return conv_res
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
"""获取会话的所有对话
|
||||
async def get_conversations(
|
||||
self, unified_msg_origin: str | None = None, platform_id: str | None = None
|
||||
) -> List[Conversation]:
|
||||
"""获取对话列表
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
|
||||
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
|
||||
Returns:
|
||||
conversations (List[Conversation]): 对话对象列表
|
||||
"""
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
convs = await self.db.get_conversations(
|
||||
user_id=unified_msg_origin, platform_id=platform_id
|
||||
)
|
||||
convs_res = []
|
||||
for conv in convs:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
convs_res.append(conv_res)
|
||||
return convs_res
|
||||
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platform_ids: list[str] | None = None,
|
||||
search_query: str = "",
|
||||
**kwargs,
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""获取过滤后的对话列表
|
||||
|
||||
Args:
|
||||
page (int): 页码, 默认为 1
|
||||
page_size (int): 每页大小, 默认为 20
|
||||
platform_ids (list[str]): 平台 ID 列表, 可选
|
||||
search_query (str): 搜索查询字符串, 可选
|
||||
Returns:
|
||||
conversations (list[Conversation]): 对话对象列表
|
||||
"""
|
||||
convs, cnt = await self.db.get_filtered_conversations(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
platform_ids=platform_ids,
|
||||
search_query=search_query,
|
||||
**kwargs,
|
||||
)
|
||||
convs_res = []
|
||||
for conv in convs:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
convs_res.append(conv_res)
|
||||
return convs_res, cnt
|
||||
|
||||
async def update_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str | None = None,
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话
|
||||
|
||||
@@ -130,40 +203,55 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
"""
|
||||
if not conversation_id:
|
||||
# 如果没有提供 conversation_id,则获取当前的
|
||||
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
await self.db.update_conversation(
|
||||
cid=conversation_id,
|
||||
history=json.dumps(history),
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
async def update_conversation_title(
|
||||
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
|
||||
):
|
||||
"""更新会话的对话标题
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
title (str): 对话标题
|
||||
|
||||
Deprecated:
|
||||
Use `update_conversation` with `title` parameter instead.
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
user_id=unified_msg_origin, cid=conversation_id, title=title
|
||||
)
|
||||
await self.update_conversation(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(
|
||||
self, unified_msg_origin: str, persona_id: str
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
persona_id: str,
|
||||
conversation_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话 Persona ID
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
persona_id (str): 对话 Persona ID
|
||||
|
||||
Deprecated:
|
||||
Use `update_conversation` with `persona_id` parameter instead.
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
|
||||
)
|
||||
await self.update_conversation(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
async def get_human_readable_context(
|
||||
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||
|
||||
@@ -15,20 +15,23 @@ import time
|
||||
import threading
|
||||
import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from . import astrbot_config, html_renderer
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
|
||||
@@ -47,12 +50,23 @@ class AstrBotCoreLifecycle:
|
||||
self.db = db # 初始化数据库
|
||||
|
||||
# 设置代理
|
||||
if self.astrbot_config.get("http_proxy", ""):
|
||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||
if proxy := os.environ.get("https_proxy"):
|
||||
logger.debug(f"Using proxy: {proxy}")
|
||||
os.environ["no_proxy"] = "localhost"
|
||||
proxy_config = self.astrbot_config.get("http_proxy", "")
|
||||
if proxy_config != "":
|
||||
os.environ["https_proxy"] = proxy_config
|
||||
os.environ["http_proxy"] = proxy_config
|
||||
logger.debug(f"Using proxy: {proxy_config}")
|
||||
# 设置 no_proxy
|
||||
no_proxy_list = self.astrbot_config.get("no_proxy", [])
|
||||
os.environ["no_proxy"] = ",".join(no_proxy_list)
|
||||
else:
|
||||
# 清空代理环境变量
|
||||
if "https_proxy" in os.environ:
|
||||
del os.environ["https_proxy"]
|
||||
if "http_proxy" in os.environ:
|
||||
del os.environ["http_proxy"]
|
||||
if "no_proxy" in os.environ:
|
||||
del os.environ["no_proxy"]
|
||||
logger.debug("HTTP proxy cleared")
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
@@ -66,11 +80,26 @@ class AstrBotCoreLifecycle:
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||
|
||||
await self.db.initialize()
|
||||
|
||||
await html_renderer.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
default_config=self.astrbot_config, sp=sp
|
||||
)
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
# 初始化人格管理器
|
||||
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
|
||||
await self.persona_mgr.initialize()
|
||||
|
||||
# 初始化供应商管理器
|
||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||
self.provider_manager = ProviderManager(
|
||||
self.astrbot_config_mgr, self.db, self.persona_mgr
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
@@ -78,6 +107,9 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
# 初始化平台消息历史管理器
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
@@ -86,6 +118,9 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.platform_message_history_manager,
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -98,16 +133,16 @@ class AstrBotCoreLifecycle:
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler = PipelineScheduler(
|
||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||
)
|
||||
await self.pipeline_scheduler.initialize()
|
||||
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
self.event_bus = EventBus(
|
||||
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
|
||||
)
|
||||
|
||||
# 记录启动时间
|
||||
self.start_time = int(time.time())
|
||||
@@ -224,6 +259,39 @@ class AstrBotCoreLifecycle:
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
tasks.append(
|
||||
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
|
||||
asyncio.create_task(
|
||||
platform_inst.run(),
|
||||
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
|
||||
"""加载消息事件流水线调度器
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
"""
|
||||
mapping = {}
|
||||
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
)
|
||||
await scheduler.initialize()
|
||||
mapping[conf_id] = scheduler
|
||||
return mapping
|
||||
|
||||
async def reload_pipeline_scheduler(self, conf_id: str):
|
||||
"""重新加载消息事件流水线调度器
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
"""
|
||||
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
|
||||
if not ab_config:
|
||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
)
|
||||
await scheduler.initialize()
|
||||
self.pipeline_scheduler_mapping[conf_id] = scheduler
|
||||
|
||||
@@ -1,7 +1,20 @@
|
||||
import abc
|
||||
import datetime
|
||||
import typing as T
|
||||
from deprecated import deprecated
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||
from astrbot.core.db.po import (
|
||||
Stats,
|
||||
PlatformStat,
|
||||
ConversationV2,
|
||||
PlatformMessageHistory,
|
||||
Attachment,
|
||||
Persona,
|
||||
Preference,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -10,152 +23,262 @@ class BaseDatabase(abc.ABC):
|
||||
数据库基类
|
||||
"""
|
||||
|
||||
DATABASE_URL = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化数据库连接"""
|
||||
pass
|
||||
|
||||
def insert_base_metrics(self, metrics: dict):
|
||||
"""插入基础指标数据"""
|
||||
self.insert_platform_metrics(metrics["platform_stats"])
|
||||
self.insert_plugin_metrics(metrics["plugin_stats"])
|
||||
self.insert_command_metrics(metrics["command_stats"])
|
||||
self.insert_llm_metrics(metrics["llm_stats"])
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
"""插入平台指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_plugin_metrics(self, metrics: dict):
|
||||
"""插入插件指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_command_metrics(self, metrics: dict):
|
||||
"""插入指令指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
"""插入 LLM 指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_llm_history(
|
||||
self, session_id: str = None, provider_type: str = None
|
||||
) -> List[LLMHistory]:
|
||||
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
|
||||
raise NotImplementedError
|
||||
@asynccontextmanager
|
||||
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session."""
|
||||
if not self.inited:
|
||||
await self.initialize()
|
||||
self.inited = True
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取基础统计数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_total_message_count(self) -> int:
|
||||
"""获取总消息数"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取基础统计数据(合并)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_atri_vision_data(self, vision_data: ATRIVision):
|
||||
"""插入 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
# New methods in v4.0.0
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data(self) -> List[ATRIVision]:
|
||||
"""获取 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime.datetime | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data_by_path_or_id(
|
||||
self, url_or_path: str, id: str
|
||||
) -> ATRIVision:
|
||||
"""通过 url 或 path 获取 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
async def count_platform_stats(self) -> int:
|
||||
"""Count the number of platform statistics records."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
"""通过 user_id 和 cid 获取 Conversation"""
|
||||
raise NotImplementedError
|
||||
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
"""新建 Conversation"""
|
||||
raise NotImplementedError
|
||||
async def get_conversations(
|
||||
self, user_id: str | None = None, platform_id: str | None = None
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations for a specific user and platform_id(optional).
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
"""更新 Conversation"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
"""删除 Conversation"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
"""更新 Conversation 标题"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
"""更新 Conversation Persona ID"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取所有对话,支持分页
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||
content is not included in the result.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_filtered_conversations(
|
||||
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
|
||||
"""Get a specific conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations with pagination."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platforms: List[str] = None,
|
||||
message_types: List[str] = None,
|
||||
search_query: str = None,
|
||||
exclude_ids: List[str] = None,
|
||||
exclude_platforms: List[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取筛选后的对话列表
|
||||
platform_ids: list[str] | None = None,
|
||||
search_query: str = "",
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationV2], int]:
|
||||
"""Get conversations filtered by platform IDs and search query."""
|
||||
...
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
platforms: 平台筛选列表
|
||||
message_types: 消息类型筛选列表
|
||||
search_query: 搜索关键词
|
||||
exclude_ids: 排除的用户ID列表
|
||||
exclude_platforms: 排除的平台列表
|
||||
@abc.abstractmethod
|
||||
async def create_conversation(
|
||||
self,
|
||||
user_id: str,
|
||||
platform_id: str,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
cid: str | None = None,
|
||||
created_at: datetime.datetime | None = None,
|
||||
updated_at: datetime.datetime | None = None,
|
||||
) -> ConversationV2:
|
||||
"""Create a new conversation."""
|
||||
...
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@abc.abstractmethod
|
||||
async def update_conversation(
|
||||
self,
|
||||
cid: str,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_conversation(self, cid: str) -> None:
|
||||
"""Delete a conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_platform_message_offset(
|
||||
self, platform_id: str, user_id: str, offset_sec: int = 86400
|
||||
) -> None:
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[PlatformMessageHistory]:
|
||||
"""Get platform message history for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_attachment(
|
||||
self,
|
||||
path: str,
|
||||
type: str,
|
||||
mime_type: str,
|
||||
):
|
||||
"""Insert a new attachment record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
|
||||
"""Get an attachment by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""Insert a new persona record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_persona_by_id(self, persona_id: str) -> Persona:
|
||||
"""Get a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_personas(self) -> list[Persona]:
|
||||
"""Get all personas for a specific bot."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str | None = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona | None:
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_persona(self, persona_id: str) -> None:
|
||||
"""Delete a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_preference_or_update(
|
||||
self, scope: str, scope_id: str, key: str, value: dict
|
||||
) -> Preference:
|
||||
"""Insert a new preference record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
|
||||
"""Get a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preferences(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""Get all preferences for a specific scope ID or key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
|
||||
"""Remove a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def clear_preferences(self, scope: str, scope_id: str) -> None:
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
# cid: str,
|
||||
# role: str,
|
||||
# content: list,
|
||||
# tool_calls: list = None,
|
||||
# tool_call_id: str = None,
|
||||
# parent_id: str = None,
|
||||
# ) -> LLMMessage:
|
||||
# """Insert a new LLM message into the conversation."""
|
||||
# ...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||
# """Get all LLM messages for a specific conversation."""
|
||||
# ...
|
||||
|
||||
64
astrbot/core/db/migration/helper.py
Normal file
64
astrbot/core/db/migration/helper.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.api import logger, sp
|
||||
from .migra_3_to_4 import (
|
||||
migration_conversation_table,
|
||||
migration_platform_table,
|
||||
migration_webchat_data,
|
||||
migration_persona_data,
|
||||
migration_preferences,
|
||||
)
|
||||
|
||||
|
||||
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
"""
|
||||
检查是否需要进行数据库迁移
|
||||
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。
|
||||
"""
|
||||
data_v3_exists = os.path.exists(get_astrbot_data_path())
|
||||
if not data_v3_exists:
|
||||
return False
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_v4"
|
||||
)
|
||||
if migration_done:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def do_migration_v4(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
astrbot_config: AstrBotConfig,
|
||||
):
|
||||
"""
|
||||
执行数据库迁移
|
||||
迁移旧的 webchat_conversation 表到新的 conversation 表。
|
||||
迁移旧的 platform 到新的 platform_stats 表。
|
||||
"""
|
||||
if not await check_migration_needed_v4(db_helper):
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移...")
|
||||
|
||||
# 执行会话表迁移
|
||||
await migration_conversation_table(db_helper, platform_id_map)
|
||||
|
||||
# 执行人格数据迁移
|
||||
await migration_persona_data(db_helper, astrbot_config)
|
||||
|
||||
# 执行 WebChat 数据迁移
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行偏好设置迁移
|
||||
await migration_preferences(db_helper,platform_id_map)
|
||||
|
||||
# 执行平台统计表迁移
|
||||
await migration_platform_table(db_helper, platform_id_map)
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_v4", True)
|
||||
|
||||
logger.info("数据库迁移完成。")
|
||||
338
astrbot/core/db/migration/migra_3_to_4.py
Normal file
338
astrbot/core/db/migration/migra_3_to_4.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
import datetime
|
||||
from .. import BaseDatabase
|
||||
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
|
||||
from .shared_preferences_v3 import sp as sp_v3
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
|
||||
from sqlalchemy import text
|
||||
|
||||
"""
|
||||
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
|
||||
2. 迁移旧的 platform 到新的 platform_stats 表。
|
||||
"""
|
||||
|
||||
|
||||
def get_platform_id(
|
||||
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
|
||||
) -> str:
|
||||
return platform_id_map.get(
|
||||
old_platform_name,
|
||||
{"platform_id": old_platform_name, "platform_type": old_platform_name},
|
||||
).get("platform_id", old_platform_name)
|
||||
|
||||
|
||||
def get_platform_type(
|
||||
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
|
||||
) -> str:
|
||||
return platform_id_map.get(
|
||||
old_platform_name,
|
||||
{"platform_id": old_platform_name, "platform_type": old_platform_name},
|
||||
).get("platform_type", old_platform_name)
|
||||
|
||||
|
||||
async def migration_conversation_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
conversations, total_cnt = db_helper_v3.get_all_conversations(
|
||||
page=1, page_size=10000000
|
||||
)
|
||||
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
platform_id = get_platform_id(
|
||||
platform_id_map, session.platform_name
|
||||
)
|
||||
session.platform_id = platform_id # 更新平台名称为新的 ID
|
||||
conv_v2 = ConversationV2(
|
||||
user_id=str(session),
|
||||
content=json.loads(conv.history) if conv.history else [],
|
||||
platform_id=platform_id,
|
||||
title=conv.title,
|
||||
persona_id=conv.persona_id,
|
||||
conversation_id=conv.cid,
|
||||
created_at=datetime.datetime.fromtimestamp(conv.created_at),
|
||||
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
|
||||
)
|
||||
dbsession.add(conv_v2)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_platform_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
secs_from_2023_4_10_to_now = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
|
||||
).total_seconds()
|
||||
offset_sec = int(secs_from_2023_4_10_to_now)
|
||||
logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。")
|
||||
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
|
||||
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
|
||||
platform_stats_v3 = stats.platform
|
||||
|
||||
if not platform_stats_v3:
|
||||
logger.info("没有找到旧平台数据,跳过迁移。")
|
||||
return
|
||||
|
||||
first_time_stamp = platform_stats_v3[0].timestamp
|
||||
end_time_stamp = platform_stats_v3[-1].timestamp
|
||||
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
|
||||
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
|
||||
|
||||
idx = 0
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
total_buckets = (end_time - start_time) // 3600
|
||||
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
|
||||
if bucket_idx % 500 == 0:
|
||||
progress = int((bucket_idx + 1) / total_buckets * 100)
|
||||
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(platform_stats_v3)
|
||||
and platform_stats_v3[idx].timestamp < bucket_end
|
||||
):
|
||||
cnt += platform_stats_v3[idx].count
|
||||
idx += 1
|
||||
if cnt == 0:
|
||||
continue
|
||||
platform_id = get_platform_id(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
platform_type = get_platform_type(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
try:
|
||||
await dbsession.execute(
|
||||
text("""
|
||||
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
|
||||
VALUES (:timestamp, :platform_id, :platform_type, :count)
|
||||
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
|
||||
count = platform_stats.count + EXCLUDED.count
|
||||
"""),
|
||||
{
|
||||
"timestamp": datetime.datetime.fromtimestamp(
|
||||
bucket_end, tz=datetime.timezone.utc
|
||||
),
|
||||
"platform_id": platform_id,
|
||||
"platform_type": platform_type,
|
||||
"count": cnt,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
|
||||
|
||||
|
||||
async def migration_webchat_data(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
conversations, total_cnt = db_helper_v3.get_all_conversations(
|
||||
page=1, page_size=10000000
|
||||
)
|
||||
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
history = json.loads(conv.history) if conv.history else []
|
||||
for msg in history:
|
||||
type_ = msg.get("type") # user type, "bot" or "user"
|
||||
new_history = PlatformMessageHistory(
|
||||
platform_id=platform_id,
|
||||
user_id=conv.cid, # we use conv.cid as user_id for webchat
|
||||
content=msg,
|
||||
sender_id=type_,
|
||||
sender_name=type_,
|
||||
)
|
||||
dbsession.add(new_history)
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_persona_data(
|
||||
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
|
||||
):
|
||||
"""
|
||||
迁移 Persona 数据到新的表中。
|
||||
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
|
||||
"""
|
||||
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
|
||||
total_personas = len(v3_persona_config)
|
||||
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
|
||||
|
||||
for idx, persona in enumerate(v3_persona_config):
|
||||
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
|
||||
progress = int((idx + 1) / total_personas * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
|
||||
try:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
mood_prompt = ""
|
||||
user_turn = True
|
||||
for mood_dialog in mood_imitation_dialogs:
|
||||
if user_turn:
|
||||
mood_prompt += f"A: {mood_dialog}\n"
|
||||
else:
|
||||
mood_prompt += f"B: {mood_dialog}\n"
|
||||
user_turn = not user_turn
|
||||
system_prompt = persona.get("prompt", "")
|
||||
if mood_prompt:
|
||||
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
|
||||
persona_new = await db_helper.insert_persona(
|
||||
persona_id=persona["name"],
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs,
|
||||
)
|
||||
logger.info(
|
||||
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
|
||||
async def migration_preferences(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
# 1. global scope migration
|
||||
keys = [
|
||||
"inactivated_llm_tools",
|
||||
"inactivated_plugins",
|
||||
"curr_provider",
|
||||
"curr_provider_tts",
|
||||
"curr_provider_stt",
|
||||
"alter_cmd",
|
||||
]
|
||||
for key in keys:
|
||||
value = sp_v3.get(key)
|
||||
if value is not None:
|
||||
await sp.put_async("global", "global", key, value)
|
||||
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
|
||||
|
||||
# 2. umo scope migration
|
||||
session_conversation = sp_v3.get("session_conversation", default={})
|
||||
for umo, conversation_id in session_conversation.items():
|
||||
if not umo or not conversation_id:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
|
||||
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
|
||||
|
||||
session_service_config = sp_v3.get("session_service_config", default={})
|
||||
for umo, config in session_service_config.items():
|
||||
if not umo or not config:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
await sp.put_async("umo", str(session), "session_service_config", config)
|
||||
|
||||
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
|
||||
|
||||
session_variables = sp_v3.get("session_variables", default={})
|
||||
for umo, variables in session_variables.items():
|
||||
if not umo or not variables:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "session_variables", variables)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
|
||||
|
||||
session_provider_perf = sp_v3.get("session_provider_perf", default={})
|
||||
for umo, perf in session_provider_perf.items():
|
||||
if not umo or not perf:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
for provider_type, provider_id in perf.items():
|
||||
await sp.put_async(
|
||||
"umo", str(session), f"provider_perf_{provider_type}", provider_id
|
||||
)
|
||||
logger.info(
|
||||
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
|
||||
45
astrbot/core/db/migration/shared_preferences_v3.py
Normal file
45
astrbot/core/db/migration/shared_preferences_v3.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
sp = SharedPreferences()
|
||||
493
astrbot/core/db/migration/sqlite_v3.py
Normal file
493
astrbot/core/db/migration/sqlite_v3.py
Normal file
@@ -0,0 +1,493 @@
|
||||
import sqlite3
|
||||
import time
|
||||
from astrbot.core.db.po import Platform, Stats
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
"""字符串格式的列表。"""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
|
||||
|
||||
INIT_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT, -- 会话 id
|
||||
cid TEXT, -- 对话 id
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
|
||||
PRAGMA encoding = 'UTF-8';
|
||||
"""
|
||||
|
||||
|
||||
class SQLiteDatabase():
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
sql = INIT_SQL
|
||||
|
||||
# 初始化数据库
|
||||
self.conn = self._get_conn(self.db_path)
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
"""
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
"""
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
has_persona_id = False
|
||||
for row in res:
|
||||
if row[1] == "title":
|
||||
has_title = True
|
||||
if row[1] == "persona_id":
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
c.close()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: Tuple = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
conn = self._get_conn(self.db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
if params:
|
||||
c.execute(sql, params)
|
||||
c.close()
|
||||
else:
|
||||
c.execute(sql)
|
||||
c.close()
|
||||
|
||||
conn.commit()
|
||||
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT * FROM platform
|
||||
"""
|
||||
+ where_clause
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform=platform)
|
||||
|
||||
def get_total_message_count(self) -> int:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT SUM(count) FROM platform
|
||||
"""
|
||||
)
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return res[0]
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT name, SUM(count), timestamp FROM platform
|
||||
"""
|
||||
+ where_clause
|
||||
+ " GROUP BY name"
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
|
||||
if not res:
|
||||
return
|
||||
|
||||
return Conversation(*res)
|
||||
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
for row in res:
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(
|
||||
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
||||
)
|
||||
return conversations
|
||||
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
"""更新对话,并且同时更新时间"""
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(history, updated_at, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(title, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(persona_id, user_id, cid),
|
||||
)
|
||||
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取所有对话,支持分页,按更新时间降序排序"""
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
try:
|
||||
# 获取总记录数
|
||||
c.execute("""
|
||||
SELECT COUNT(*) FROM webchat_conversation
|
||||
""")
|
||||
total_count = c.fetchone()[0]
|
||||
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取分页数据,按更新时间降序排序
|
||||
c.execute(
|
||||
"""
|
||||
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||
FROM webchat_conversation
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(page_size, offset),
|
||||
)
|
||||
|
||||
rows = c.fetchall()
|
||||
|
||||
conversations = []
|
||||
|
||||
for row in rows:
|
||||
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
|
||||
safe_cid = str(cid) if cid else "unknown"
|
||||
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||
|
||||
conversations.append(
|
||||
{
|
||||
"user_id": user_id or "",
|
||||
"cid": safe_cid,
|
||||
"title": title or f"对话 {display_cid}",
|
||||
"persona_id": persona_id or "",
|
||||
"created_at": created_at or 0,
|
||||
"updated_at": updated_at or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total_count
|
||||
|
||||
except Exception as _:
|
||||
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||
return [], 0
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platforms: List[str] = None,
|
||||
message_types: List[str] = None,
|
||||
search_query: str = None,
|
||||
exclude_ids: List[str] = None,
|
||||
exclude_platforms: List[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取筛选后的对话列表"""
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
where_clauses = []
|
||||
params = []
|
||||
|
||||
# 平台筛选
|
||||
if platforms and len(platforms) > 0:
|
||||
platform_conditions = []
|
||||
for platform in platforms:
|
||||
platform_conditions.append("user_id LIKE ?")
|
||||
params.append(f"{platform}:%")
|
||||
|
||||
if platform_conditions:
|
||||
where_clauses.append(f"({' OR '.join(platform_conditions)})")
|
||||
|
||||
# 消息类型筛选
|
||||
if message_types and len(message_types) > 0:
|
||||
message_type_conditions = []
|
||||
for msg_type in message_types:
|
||||
message_type_conditions.append("user_id LIKE ?")
|
||||
params.append(f"%:{msg_type}:%")
|
||||
|
||||
if message_type_conditions:
|
||||
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
||||
|
||||
# 搜索关键词
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
where_clauses.append(
|
||||
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
||||
)
|
||||
search_param = f"%{search_query}%"
|
||||
params.extend([search_param, search_param, search_param, search_param])
|
||||
|
||||
# 排除特定用户ID
|
||||
if exclude_ids and len(exclude_ids) > 0:
|
||||
for exclude_id in exclude_ids:
|
||||
where_clauses.append("user_id NOT LIKE ?")
|
||||
params.append(f"{exclude_id}%")
|
||||
|
||||
# 排除特定平台
|
||||
if exclude_platforms and len(exclude_platforms) > 0:
|
||||
for exclude_platform in exclude_platforms:
|
||||
where_clauses.append("user_id NOT LIKE ?")
|
||||
params.append(f"{exclude_platform}:%")
|
||||
|
||||
# 构建完整的 WHERE 子句
|
||||
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||
|
||||
# 构建计数查询
|
||||
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
|
||||
|
||||
# 获取总记录数
|
||||
c.execute(count_sql, params)
|
||||
total_count = c.fetchone()[0]
|
||||
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 构建分页数据查询
|
||||
data_sql = f"""
|
||||
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||
FROM webchat_conversation
|
||||
{where_sql}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
query_params = params + [page_size, offset]
|
||||
|
||||
# 获取分页数据
|
||||
c.execute(data_sql, query_params)
|
||||
rows = c.fetchall()
|
||||
|
||||
conversations = []
|
||||
|
||||
for row in rows:
|
||||
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||
# 确保 cid 是字符串类型,否则使用一个默认值
|
||||
safe_cid = str(cid) if cid else "unknown"
|
||||
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||
|
||||
conversations.append(
|
||||
{
|
||||
"user_id": user_id or "",
|
||||
"cid": safe_cid,
|
||||
"title": title or f"对话 {display_cid}",
|
||||
"persona_id": persona_id or "",
|
||||
"created_at": created_at or 0,
|
||||
"updated_at": updated_at or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total_count
|
||||
|
||||
except Exception as _:
|
||||
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||
return [], 0
|
||||
finally:
|
||||
c.close()
|
||||
@@ -1,7 +1,233 @@
|
||||
"""指标数据"""
|
||||
import uuid
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
from sqlmodel import (
|
||||
SQLModel,
|
||||
Text,
|
||||
JSON,
|
||||
UniqueConstraint,
|
||||
Field,
|
||||
)
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
|
||||
class PlatformStat(SQLModel, table=True):
|
||||
"""This class represents the statistics of bot usage across different platforms.
|
||||
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
platform_id: str = Field(nullable=False)
|
||||
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
|
||||
count: int = Field(default=0, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"timestamp",
|
||||
"platform_id",
|
||||
"platform_type",
|
||||
name="uix_platform_stats",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conversation_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False)
|
||||
content: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
title: Optional[str] = Field(default=None, max_length=255)
|
||||
persona_id: Optional[str] = Field(default=None)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conversation_id",
|
||||
name="uix_conversation_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Persona(SQLModel, table=True):
|
||||
"""Persona is a set of instructions for LLMs to follow.
|
||||
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
persona_id: str = Field(max_length=255, nullable=False)
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""a list of strings, each representing a dialog to start with"""
|
||||
tools: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"persona_id",
|
||||
name="uix_persona_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
scope: str = Field(nullable=False)
|
||||
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
|
||||
scope_id: str = Field(nullable=False)
|
||||
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
|
||||
key: str = Field(nullable=False)
|
||||
value: dict = Field(sa_type=JSON, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"scope",
|
||||
"scope_id",
|
||||
"key",
|
||||
name="uix_preference_scope_scope_id_key",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PlatformMessageHistory(SQLModel, table=True):
|
||||
"""This class represents the message history for a specific platform.
|
||||
|
||||
It is used to store messages that are not LLM-generated, such as user messages
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
sender_name: Optional[str] = Field(
|
||||
default=None
|
||||
) # Name of the sender in the platform
|
||||
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class Attachment(SQLModel, table=True):
|
||||
"""This class represents attachments for messages in AstrBot.
|
||||
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
path: str = Field(nullable=False) # Path to the file on disk
|
||||
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
|
||||
mime_type: str = Field(nullable=False) # MIME type of the file
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"attachment_id",
|
||||
name="uix_attachment_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
|
||||
对于 WebChat,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
|
||||
在 v4.0.0 版本及之后,WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
|
||||
"""
|
||||
|
||||
platform_id: str
|
||||
user_id: str
|
||||
cid: str
|
||||
"""对话 ID, 是 uuid 格式的字符串"""
|
||||
history: str = ""
|
||||
"""字符串格式的对话列表。"""
|
||||
title: str | None = ""
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
"""LLM 人格类。
|
||||
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
# ====
|
||||
# Deprecated, and will be removed in future versions.
|
||||
# ====
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -13,77 +239,6 @@ class Platform:
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""供应商使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plugin:
|
||||
"""插件使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Command:
|
||||
"""命令使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Stats:
|
||||
platform: List[Platform] = field(default_factory=list)
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMHistory:
|
||||
"""LLM 聊天时持久化的信息"""
|
||||
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ATRIVision:
|
||||
"""Deprecated"""
|
||||
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
is_meme: bool
|
||||
keywords: List[str]
|
||||
platform_name: str
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
"""字符串格式的列表。"""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
platform: list[Platform] = field(default_factory=list)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,50 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT, -- 会话 id
|
||||
cid TEXT, -- 对话 id
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
|
||||
PRAGMA encoding = 'UTF-8';
|
||||
@@ -5,6 +5,7 @@ from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
@@ -17,6 +18,7 @@ class FaissVecDB(BaseVecDB):
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
rerank_provider: RerankProvider | None = None,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
@@ -26,11 +28,14 @@ class FaissVecDB(BaseVecDB):
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
async def insert(
|
||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
||||
) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
@@ -53,7 +58,12 @@ class FaissVecDB(BaseVecDB):
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||
self,
|
||||
query: str,
|
||||
k: int = 5,
|
||||
fetch_k: int = 20,
|
||||
rerank: bool = False,
|
||||
metadata_filters: dict | None = None,
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
@@ -62,6 +72,7 @@ class FaissVecDB(BaseVecDB):
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
@@ -72,7 +83,6 @@ class FaissVecDB(BaseVecDB):
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
# TODO: rerank
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
@@ -83,7 +93,7 @@ class FaissVecDB(BaseVecDB):
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs = []
|
||||
result_docs: list[Result] = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
@@ -93,7 +103,20 @@ class FaissVecDB(BaseVecDB):
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
return result_docs[:k]
|
||||
|
||||
top_k_results = result_docs[:k]
|
||||
|
||||
if rerank and self.rerank_provider:
|
||||
documents = [doc.data["text"] for doc in top_k_results]
|
||||
reranked_results = await self.rerank_provider.rerank(query, documents)
|
||||
reranked_results = sorted(
|
||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
||||
)
|
||||
top_k_results = [
|
||||
top_k_results[reranked_result.index] for reranked_result in reranked_results
|
||||
]
|
||||
|
||||
return top_k_results
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
|
||||
@@ -16,30 +16,32 @@ from asyncio import Queue
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||
from astrbot.core import logger
|
||||
from .platform import AstrMessageEvent
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""事件总线: 用于处理事件的分发和处理
|
||||
"""用于处理事件的分发和处理"""
|
||||
|
||||
维护一个异步队列, 来接受各种消息事件
|
||||
"""
|
||||
|
||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
async def dispatch(self):
|
||||
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||
while True:
|
||||
event: AstrMessageEvent = (
|
||||
await self.event_queue.get()
|
||||
) # 从事件队列中获取新的事件
|
||||
self._print_event(event) # 打印日志
|
||||
asyncio.create_task(
|
||||
self.pipeline_scheduler.execute(event)
|
||||
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent):
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
@@ -48,10 +50,10 @@ class EventBus:
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
if event.get_sender_name():
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||
else:
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from urllib.parse import urlparse, unquote
|
||||
import platform
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
@@ -15,7 +17,9 @@ class FileTokenService:
|
||||
async def _cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
||||
expired_tokens = [
|
||||
token for token, (_, expire) in self.staged_files.items() if expire < now
|
||||
]
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
@@ -32,15 +36,35 @@ class FileTokenService:
|
||||
Raises:
|
||||
FileNotFoundError: 当路径不存在时抛出
|
||||
"""
|
||||
|
||||
# 处理 file:///
|
||||
try:
|
||||
parsed_uri = urlparse(file_path)
|
||||
if parsed_uri.scheme == "file":
|
||||
local_path = unquote(parsed_uri.path)
|
||||
if platform.system() == "Windows" and local_path.startswith("/"):
|
||||
local_path = local_path[1:]
|
||||
else:
|
||||
# 如果没有 file:/// 前缀,则认为是普通路径
|
||||
local_path = file_path
|
||||
except Exception:
|
||||
# 解析失败时,按原路径处理
|
||||
local_path = file_path
|
||||
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
if not os.path.exists(local_path):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在: {local_path} (原始输入: {file_path})"
|
||||
)
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
||||
self.staged_files[file_token] = (file_path, expire_time)
|
||||
expire_time = time.time() + (
|
||||
timeout if timeout is not None else self.default_timeout
|
||||
)
|
||||
# 存储转换后的真实路径
|
||||
self.staged_files[file_token] = (local_path, expire_time)
|
||||
return file_token
|
||||
|
||||
async def handle_file(self, file_token: str) -> str:
|
||||
|
||||
@@ -96,8 +96,6 @@ class LogBroker:
|
||||
Queue: 订阅者的队列, 可用于接收日志消息
|
||||
"""
|
||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||
for log in self.log_cache:
|
||||
q.put_nowait(log)
|
||||
self.subscribers.append(q)
|
||||
return q
|
||||
|
||||
|
||||
183
astrbot/core/persona_mgr.py
Normal file
183
astrbot/core/persona_mgr.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Persona, Personality
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot import logger
|
||||
|
||||
DEFAULT_PERSONALITY = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
begin_dialogs=[],
|
||||
mood_imitation_dialogs=[],
|
||||
tools=None,
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed="",
|
||||
)
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
|
||||
self.db = db_helper
|
||||
self.acm = acm
|
||||
default_ps = acm.default_conf.get("provider_settings", {})
|
||||
self.default_persona: str = default_ps.get("default_personality", "default")
|
||||
self.personas: list[Persona] = []
|
||||
self.selected_default_persona: Persona | None = None
|
||||
|
||||
self.personas_v3: list[Personality] = []
|
||||
self.selected_default_persona_v3: Personality | None = None
|
||||
self.persona_v3_config: list[dict] = []
|
||||
|
||||
async def initialize(self):
|
||||
self.personas = await self.get_all_personas()
|
||||
self.get_v3_persona_data()
|
||||
logger.info(f"已加载 {len(self.personas)} 个人格。")
|
||||
|
||||
async def get_persona(self, persona_id: str):
|
||||
"""获取指定 persona 的信息"""
|
||||
persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
return persona
|
||||
|
||||
async def get_default_persona_v3(
|
||||
self, umo: str | MessageSession | None = None
|
||||
) -> Personality:
|
||||
"""获取默认 persona"""
|
||||
cfg = self.acm.get_conf(umo)
|
||||
default_persona_id = cfg.get("provider_settings", {}).get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
if not default_persona_id or default_persona_id == "default":
|
||||
return DEFAULT_PERSONALITY
|
||||
try:
|
||||
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
|
||||
except Exception:
|
||||
return DEFAULT_PERSONALITY
|
||||
|
||||
async def delete_persona(self, persona_id: str):
|
||||
"""删除指定 persona"""
|
||||
if not await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
await self.db.delete_persona(persona_id)
|
||||
self.personas = [p for p in self.personas if p.persona_id != persona_id]
|
||||
self.get_v3_persona_data()
|
||||
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
):
|
||||
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
existing_persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not existing_persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
persona = await self.db.update_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
if persona:
|
||||
for i, p in enumerate(self.personas):
|
||||
if p.persona_id == persona_id:
|
||||
self.personas[i] = persona
|
||||
break
|
||||
self.get_v3_persona_data()
|
||||
return persona
|
||||
|
||||
async def get_all_personas(self) -> list[Persona]:
|
||||
"""获取所有 personas"""
|
||||
return await self.db.get_personas()
|
||||
|
||||
async def create_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} already exists.")
|
||||
new_persona = await self.db.insert_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
self.personas.append(new_persona)
|
||||
self.get_v3_persona_data()
|
||||
return new_persona
|
||||
|
||||
def get_v3_persona_data(
|
||||
self,
|
||||
) -> tuple[list[dict], list[Personality], Personality]:
|
||||
"""获取 AstrBot <4.0.0 版本的 persona 数据。
|
||||
|
||||
Returns:
|
||||
- list[dict]: 包含 persona 配置的字典列表。
|
||||
- list[Personality]: 包含 Personality 对象的列表。
|
||||
- Personality: 默认选择的 Personality 对象。
|
||||
"""
|
||||
v3_persona_config = [
|
||||
{
|
||||
"prompt": persona.system_prompt,
|
||||
"name": persona.persona_id,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"mood_imitation_dialogs": [], # deprecated
|
||||
"tools": persona.tools,
|
||||
}
|
||||
for persona in self.personas
|
||||
]
|
||||
|
||||
personas_v3: list[Personality] = []
|
||||
selected_default_persona: Personality | None = None
|
||||
|
||||
for persona_cfg in v3_persona_config:
|
||||
begin_dialogs = persona_cfg.get("begin_dialogs", [])
|
||||
bd_processed = []
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona_cfg,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed="", # deprecated
|
||||
)
|
||||
if persona["name"] == self.default_persona:
|
||||
selected_default_persona = persona
|
||||
personas_v3.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not selected_default_persona and len(personas_v3) > 0:
|
||||
# 默认选择第一个
|
||||
selected_default_persona = personas_v3[0]
|
||||
|
||||
if not selected_default_persona:
|
||||
selected_default_persona = DEFAULT_PERSONALITY
|
||||
personas_v3.append(selected_default_persona)
|
||||
|
||||
self.personas_v3 = personas_v3
|
||||
self.selected_default_persona_v3 = selected_default_persona
|
||||
self.persona_v3_config = v3_persona_config
|
||||
self.selected_default_persona = Persona(
|
||||
persona_id=selected_default_persona["name"],
|
||||
system_prompt=selected_default_persona["prompt"],
|
||||
begin_dialogs=selected_default_persona["begin_dialogs"],
|
||||
tools=selected_default_persona["tools"] or None,
|
||||
)
|
||||
|
||||
return v3_persona_config, personas_v3, selected_default_persona
|
||||
@@ -1,25 +1,25 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
EventResultType,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
@@ -29,9 +29,9 @@ STAGES_ORDER = [
|
||||
__all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from .context_utils import call_handler, call_event_hook
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -17,97 +10,6 @@ class PipelineContext:
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
|
||||
async def call_event_hook(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
) -> bool:
|
||||
"""调用事件钩子函数
|
||||
|
||||
Returns:
|
||||
bool: 如果事件被终止,返回 True
|
||||
"""
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
async def call_handler(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
astrbot_config_id: str
|
||||
call_handler = call_handler
|
||||
call_event_hook = call_event_hook
|
||||
|
||||
98
astrbot/core/pipeline/context_utils.py
Normal file
98
astrbot/core/pipeline/context_utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from astrbot import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
|
||||
async def call_event_hook(
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""调用事件钩子函数
|
||||
|
||||
Returns:
|
||||
bool: 如果事件被终止,返回 True
|
||||
# """
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args, **kwargs)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
|
||||
return event.is_stopped()
|
||||
@@ -1,56 +0,0 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class PlatformCompatibilityStage(Stage):
|
||||
"""检查所有处理器的平台兼容性。
|
||||
|
||||
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化平台兼容性检查阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
|
||||
# 获取已激活的处理器
|
||||
activated_handlers = event.get_extra("activated_handlers")
|
||||
if activated_handlers is None:
|
||||
activated_handlers = []
|
||||
|
||||
# 标记不兼容的处理器
|
||||
for handler in activated_handlers:
|
||||
if not isinstance(handler, StarHandlerMetadata):
|
||||
continue
|
||||
# 检查处理器是否在当前平台启用
|
||||
enabled = handler.is_enabled_for_platform(platform_id)
|
||||
if not enabled:
|
||||
if handler.handler_module_path in star_map:
|
||||
plugin_name = star_map[handler.handler_module_path].name
|
||||
logger.debug(
|
||||
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||
)
|
||||
# 设置处理器为平台不兼容状态
|
||||
# TODO: 更好的标记方式
|
||||
handler.platform_compatible = False
|
||||
else:
|
||||
# 确保处理器为平台兼容状态
|
||||
handler.platform_compatible = True
|
||||
|
||||
# 更新已激活的处理器列表
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
@@ -2,29 +2,288 @@
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import copy
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext, call_event_hook, call_handler
|
||||
from ..stage import Stage
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
try:
|
||||
import mcp
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
async def execute(cls, tool, run_context, **tool_args):
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
"""
|
||||
if isinstance(tool, HandoffTool):
|
||||
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
if tool.origin == "local":
|
||||
async for r in cls._execute_local(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
elif tool.origin == "mcp":
|
||||
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
raise Exception(f"Unknown function origin: {tool.origin}")
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
tool: HandoffTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input", "agent")
|
||||
agent_runner = AgentRunner()
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description,
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=run_context.context.provider,
|
||||
first_provider_request=run_context.context.first_provider_request,
|
||||
curr_provider_request=request,
|
||||
streaming=run_context.context.streaming,
|
||||
)
|
||||
|
||||
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
||||
await run_context.event.send(
|
||||
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
|
||||
)
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=run_context.context.provider,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx, event=run_context.event
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
streaming=run_context.context.streaming,
|
||||
)
|
||||
|
||||
async for _ in run_agent(agent_runner, 15, True):
|
||||
pass
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||
)
|
||||
|
||||
result = (
|
||||
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
|
||||
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
|
||||
)
|
||||
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=result,
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
yield mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
if not run_context.event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
|
||||
# 检查 tool 下有没有 run 方法
|
||||
if not tool.handler and not hasattr(tool, "run"):
|
||||
raise ValueError("Tool must have a valid handler or 'run' method.")
|
||||
awaitable = tool.handler or getattr(tool, "run")
|
||||
|
||||
wrapper = call_handler(
|
||||
event=run_context.event,
|
||||
handler=awaitable,
|
||||
**tool_args,
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
if not tool.mcp_client:
|
||||
raise ValueError("MCP client is not available for MCP function tools.")
|
||||
res = await tool.mcp_client.session.call_tool(
|
||||
name=tool.name,
|
||||
arguments=tool_args,
|
||||
)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
run_context.event, EventType.OnLLMResponseEvent, llm_response
|
||||
)
|
||||
|
||||
|
||||
MAIN_AGENT_HOOKS = MainAgentHooks()
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
|
||||
) -> AsyncGenerator[MessageChain, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.event
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
resp.data["chain"].type = "tool_call_result"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use or astr_event.get_platform_name() == "webchat":
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not agent_runner.streaming:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
astr_event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
astr_event.clear_result()
|
||||
else:
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
astr_event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -64,6 +323,20 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent):
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
# 获取对话上下文
|
||||
cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
if not cid:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
return conversation
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, _nested: bool = False
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -72,6 +345,12 @@ class LLMRequestSubStage(Stage):
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
@@ -93,30 +372,14 @@ class LLMRequestSubStage(Stage):
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
# 获取对话上下文
|
||||
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if not conversation_id:
|
||||
conversation_id = await self.conv_manager.new_conversation(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
conversation_id = await self.conv_manager.new_conversation(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, conversation_id
|
||||
)
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
@@ -126,7 +389,7 @@ class LLMRequestSubStage(Stage):
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
@@ -160,88 +423,62 @@ class LLMRequestSubStage(Stage):
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
# Call Agent
|
||||
tool_loop_agent = ToolLoopAgent(
|
||||
provider=provider,
|
||||
event=event,
|
||||
pipeline_ctx=self.ctx,
|
||||
# check provider modalities
|
||||
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
|
||||
req.image_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。")
|
||||
req.func_tool = None
|
||||
# 插件可用性设置
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
plugin = star_map.get(tool.handler_module_path)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=provider,
|
||||
first_provider_request=req,
|
||||
curr_provider_request=req,
|
||||
streaming=self.streaming_response,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=self.streaming_response,
|
||||
)
|
||||
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||
|
||||
async def requesting():
|
||||
step_idx = 0
|
||||
while step_idx < self.max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in tool_loop_agent.step():
|
||||
if event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
resp.data["chain"].type = "tool_call_result"
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
if resp.type == "tool_call":
|
||||
if self.streaming_response:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if (
|
||||
self.show_tool_use
|
||||
or event.get_platform_name() == "webchat"
|
||||
):
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not self.streaming_response:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
event.clear_result()
|
||||
else:
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if tool_loop_agent.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=provider.get_model(),
|
||||
provider_type=provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting())
|
||||
.set_async_stream(
|
||||
run_agent(agent_runner, self.max_step, self.show_tool_use)
|
||||
)
|
||||
)
|
||||
yield
|
||||
if tool_loop_agent.done():
|
||||
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
@@ -255,15 +492,15 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
)
|
||||
else:
|
||||
async for _ in requesting():
|
||||
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
||||
yield
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
|
||||
|
||||
async def _handle_webchat(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
@@ -296,19 +533,10 @@ class LLMRequestSubStage(Stage):
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
event.unified_msg_origin, title=title
|
||||
unified_msg_origin=event.unified_msg_origin,
|
||||
title=title,
|
||||
conversation_id=req.conversation.cid,
|
||||
)
|
||||
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
|
||||
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
|
||||
# TODO: 优化 WebChat 适配器的对话管理
|
||||
if event.session_id:
|
||||
username, cid = event.session_id.split("!")[1:3]
|
||||
db_helper = self.ctx.plugin_manager.context._db
|
||||
db_helper.update_conversation_title(
|
||||
user_id=username,
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
@@ -324,6 +552,10 @@ class LLMRequestSubStage(Stage):
|
||||
):
|
||||
return
|
||||
|
||||
if not llm_response.completion_text and not req.tool_calls_result:
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||
"""
|
||||
|
||||
from ...context import PipelineContext
|
||||
from ...context import PipelineContext, call_handler
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -33,16 +33,6 @@ class StarRequestSubStage(Stage):
|
||||
handlers_parsed_params = {}
|
||||
|
||||
for handler in activated_handlers:
|
||||
# 检查处理器是否在当前平台兼容
|
||||
if (
|
||||
hasattr(handler, "platform_compatible")
|
||||
and handler.platform_compatible is False
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||
)
|
||||
continue
|
||||
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
@@ -50,7 +40,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
)
|
||||
wrapper = self.ctx.call_handler(event, handler.handler, **params)
|
||||
wrapper = call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -127,7 +128,7 @@ class RespondStage(Stage):
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
logger.info(f"应用流式输出({event.get_platform_id()})")
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
@@ -143,8 +144,6 @@ class RespondStage(Stage):
|
||||
try:
|
||||
if await self._is_empty_message_chain(result.chain):
|
||||
logger.info("消息为空,跳过发送阶段")
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
@@ -177,25 +176,26 @@ class RespondStage(Stage):
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
@@ -214,7 +214,7 @@ class RespondStage(Stage):
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -3,11 +3,12 @@ import time
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
@@ -63,9 +64,10 @@ class ResultDecorateStage(Stage):
|
||||
]
|
||||
self.content_safe_check_stage = None
|
||||
if self.content_safe_check_reply:
|
||||
for stage in registered_stages:
|
||||
if stage.__class__.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage
|
||||
for stage_cls in registered_stages:
|
||||
if stage_cls.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -97,7 +99,7 @@ class ResultDecorateStage(Stage):
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
@@ -176,10 +178,12 @@ class ResultDecorateStage(Stage):
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
|
||||
@@ -11,16 +11,17 @@ class PipelineScheduler:
|
||||
|
||||
def __init__(self, context: PipelineContext):
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||
key=lambda x: STAGES_ORDER.index(x.__name__)
|
||||
) # 按照顺序排序
|
||||
self.ctx = context # 上下文对象
|
||||
self.stages = [] # 存储阶段实例
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管道调度器时, 初始化所有阶段"""
|
||||
for stage in registered_stages:
|
||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
for stage_cls in registered_stages:
|
||||
stage_instance = stage_cls() # 创建实例
|
||||
await stage_instance.initialize(self.ctx)
|
||||
self.stages.append(stage_instance)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
"""依次执行各个阶段
|
||||
@@ -29,9 +30,9 @@ class PipelineScheduler:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||
"""
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
for i in range(from_stage, len(self.stages)):
|
||||
stage = self.stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
|
||||
coroutine = stage.process(
|
||||
event
|
||||
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||
@@ -73,7 +74,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||
if event.get_platform_name() == "webchat":
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
event.stop_event()
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import List, AsyncGenerator, Union
|
||||
from typing import List, AsyncGenerator, Union, Type
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
|
||||
|
||||
|
||||
def register_stage(cls):
|
||||
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
|
||||
registered_stages.append(cls())
|
||||
registered_stages.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -109,8 +112,17 @@ class WakingCheckStage(Stage):
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
|
||||
# 将 plugins_name 设置到 event 中
|
||||
enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"])
|
||||
if enabled_plugins_name == ["*"]:
|
||||
# 如果是 *,则表示所有插件都启用
|
||||
event.plugins_name = None
|
||||
else:
|
||||
event.plugins_name = enabled_plugins_name
|
||||
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
|
||||
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.AdapterMessageEvent
|
||||
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
|
||||
):
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
@@ -166,6 +178,11 @@ class WakingCheckStage(Stage):
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
event, activated_handlers
|
||||
)
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@ import asyncio
|
||||
import re
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
Plain,
|
||||
@@ -23,21 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
platform_name: str
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
@staticmethod
|
||||
def from_str(session_str: str):
|
||||
platform_name, message_type, session_id = session_str.split(":")
|
||||
return MessageSesion(platform_name, MessageType(message_type), session_id)
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
@@ -64,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -78,13 +65,23 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.call_llm = False
|
||||
"""是否在此消息事件中禁止默认的 LLM 请求"""
|
||||
|
||||
self.plugins_name: list[str] | None = None
|
||||
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
|
||||
|
||||
# back_compability
|
||||
self.platform = platform_meta
|
||||
|
||||
def get_platform_name(self):
|
||||
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
|
||||
|
||||
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
|
||||
return self.platform_meta.name
|
||||
|
||||
def get_platform_id(self):
|
||||
"""获取这个事件所属的平台的 ID。
|
||||
|
||||
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
|
||||
"""
|
||||
return self.platform_meta.id
|
||||
|
||||
def get_message_str(self) -> str:
|
||||
@@ -188,6 +185,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
清除额外的信息。
|
||||
"""
|
||||
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
|
||||
self._extras.clear()
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
@@ -227,7 +225,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp, gewechat。
|
||||
Fallback仅支持 aiocqhttp。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
@@ -419,7 +417,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
适配情况:
|
||||
|
||||
- gewechat
|
||||
- aiocqhttp(OneBotv11)
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -18,6 +18,9 @@ class PlatformManager:
|
||||
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
这个配置中的 unique_session 需要特殊处理,
|
||||
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
|
||||
self.event_queue = event_queue
|
||||
|
||||
async def initialize(self):
|
||||
@@ -58,10 +61,6 @@ class PlatformManager:
|
||||
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
|
||||
28
astrbot/core/platform/message_session.py
Normal file
28
astrbot/core/platform/message_session.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSession:
|
||||
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
|
||||
如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
|
||||
|
||||
platform_name: str
|
||||
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
platform_id: str = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
def __post_init__(self):
|
||||
self.platform_id = self.platform_name
|
||||
|
||||
@staticmethod
|
||||
def from_str(session_str: str):
|
||||
platform_id, message_type, session_id = session_str.split(":")
|
||||
return MessageSession(platform_id, MessageType(message_type), session_id)
|
||||
|
||||
|
||||
MessageSesion = MessageSession # back compatibility
|
||||
@@ -5,7 +5,7 @@ from asyncio import Queue
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .astr_message_event import MessageSesion
|
||||
from .message_session import MessageSesion
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
@dataclass
|
||||
class PlatformMetadata:
|
||||
name: str
|
||||
"""平台的名称"""
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
@@ -58,50 +58,85 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
@classmethod
|
||||
async def _dispatch_send(
|
||||
cls,
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str,
|
||||
messages: list[dict],
|
||||
):
|
||||
if event:
|
||||
await bot.send(event=event, message=messages)
|
||||
elif is_group:
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
else:
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
|
||||
@classmethod
|
||||
async def send_message(
|
||||
cls,
|
||||
bot: CQHttp,
|
||||
message_chain: MessageChain,
|
||||
event: Event | None = None,
|
||||
is_group: bool = False,
|
||||
session_id: str = None,
|
||||
):
|
||||
"""发送消息"""
|
||||
|
||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||
send_one_by_one = any(
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
|
||||
)
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 合并转发消息
|
||||
|
||||
if isinstance(seg, Node):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
else:
|
||||
payload["user_id"] = self.get_sender_id()
|
||||
await self.bot.call_action(
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
elif isinstance(seg, File):
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
[d],
|
||||
)
|
||||
else:
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
await AiocqhttpMessageEvent._parse_onebot_json(
|
||||
MessageChain([seg])
|
||||
),
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if not send_one_by_one:
|
||||
ret = await cls._parse_onebot_json(message_chain)
|
||||
if not ret:
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, ret)
|
||||
return
|
||||
for seg in message_chain.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 合并转发消息
|
||||
if isinstance(seg, Node):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if is_group:
|
||||
payload["group_id"] = session_id
|
||||
await bot.call_action("send_group_forward_msg", **payload)
|
||||
else:
|
||||
payload["user_id"] = session_id
|
||||
await bot.call_action("send_private_forward_msg", **payload)
|
||||
elif isinstance(seg, File):
|
||||
d = await cls._from_segment_to_dict(seg)
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, [d])
|
||||
else:
|
||||
messages = await cls._parse_onebot_json(MessageChain([seg]))
|
||||
if not messages:
|
||||
continue
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, messages)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息"""
|
||||
event = self.message_obj.raw_message
|
||||
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
|
||||
is_group = False
|
||||
if self.get_group_id():
|
||||
is_group = True
|
||||
session_id = self.get_group_id()
|
||||
else:
|
||||
session_id = self.get_sender_id()
|
||||
await self.send_message(
|
||||
bot=self.bot,
|
||||
message_chain=message,
|
||||
event=event,
|
||||
is_group=is_group,
|
||||
session_id=session_id,
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
|
||||
@@ -83,19 +83,18 @@ class AiocqhttpAdapter(Platform):
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
||||
match session.message_type.value:
|
||||
case MessageType.GROUP_MESSAGE.value:
|
||||
if "_" in session.session_id:
|
||||
# 独立会话
|
||||
_, group_id = session.session_id.split("_")
|
||||
await self.bot.send_group_msg(group_id=group_id, message=ret)
|
||||
else:
|
||||
await self.bot.send_group_msg(
|
||||
group_id=session.session_id, message=ret
|
||||
)
|
||||
case MessageType.FRIEND_MESSAGE.value:
|
||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||
is_group = session.message_type == MessageType.GROUP_MESSAGE
|
||||
if is_group:
|
||||
session_id = session.session_id.split("_")[-1]
|
||||
else:
|
||||
session_id = session.session_id
|
||||
await AiocqhttpMessageEvent.send_message(
|
||||
bot=self.bot,
|
||||
message_chain=message_chain,
|
||||
event=None, # 这里不需要 event,因为是通过 session 发送的
|
||||
is_group=is_group,
|
||||
session_id=session_id,
|
||||
)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
@@ -273,8 +272,14 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
|
||||
reply_event_data["post_type"] = "message"
|
||||
new_event = Event.from_payload(reply_event_data)
|
||||
if not new_event:
|
||||
logger.error(
|
||||
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
|
||||
)
|
||||
continue
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
new_event, get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
@@ -307,7 +312,9 @@ class AiocqhttpAdapter(Platform):
|
||||
user_id=int(m["data"]["qq"]),
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
|
||||
nickname = at_info.get("nick", "") or at_info.get(
|
||||
"nickname", ""
|
||||
)
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
|
||||
@@ -26,7 +26,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"AstrBot",
|
||||
segment.text,
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
|
||||
@@ -1,812 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SimpleGewechatClient:
|
||||
"""针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
nickname: str,
|
||||
host: str,
|
||||
port: int,
|
||||
event_queue: asyncio.Queue,
|
||||
):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith("/"):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
|
||||
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_token>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
self.userrealnames = {}
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self.staged_files = {}
|
||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob["data"]
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {"X-GEWE-TOKEN": self.token}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
if "TypeName" in data:
|
||||
type_name = data["TypeName"]
|
||||
elif "type_name" in data:
|
||||
type_name = data["type_name"]
|
||||
else:
|
||||
raise Exception("无法识别的消息类型")
|
||||
|
||||
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||
if type_name == "ModContacts":
|
||||
logger.info("gewechat下发:ModContacts消息通知。")
|
||||
return
|
||||
if type_name == "DelContacts":
|
||||
logger.info("gewechat下发:DelContacts消息通知。")
|
||||
return
|
||||
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
d = None
|
||||
if "Data" in data:
|
||||
d = data["Data"]
|
||||
elif "data" in data:
|
||||
d = data["data"]
|
||||
|
||||
if not d:
|
||||
logger.warning(f"消息不含 data 字段: {data}")
|
||||
return
|
||||
|
||||
if "CreateTime" in d:
|
||||
# 得到系统 UTF+8 的 ts
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
ts = datetime.datetime.now(tz).timestamp()
|
||||
create_time = d["CreateTime"]
|
||||
if create_time < ts - 30:
|
||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
from_user_name = d["FromUserName"]["string"] # 消息来源
|
||||
d["to_wxid"] = from_user_name # 用于发信息
|
||||
|
||||
abm.message_id = str(d.get("MsgId"))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data["Wxid"] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
at_wxids = []
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
at_wxids = re.findall(
|
||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||
msg_source,
|
||||
)
|
||||
|
||||
abm.group_id = from_user_name
|
||||
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
):
|
||||
at_me = True
|
||||
if "在群聊中@了你" in d.get("PushContent", ""):
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||
# if user_id == abm.self_id:
|
||||
# logger.info("忽略自己发送的消息")
|
||||
# return None
|
||||
|
||||
abm.message = []
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
if abm.group_id:
|
||||
if (
|
||||
abm.group_id not in self.userrealnames
|
||||
or user_id not in self.userrealnames[abm.group_id]
|
||||
):
|
||||
# 获取群成员列表,并且缓存
|
||||
if abm.group_id not in self.userrealnames:
|
||||
self.userrealnames[abm.group_id] = {}
|
||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||
if member_list and "memberList" in member_list:
|
||||
for member in member_list["memberList"]:
|
||||
self.userrealnames[abm.group_id][member["wxid"]] = member[
|
||||
"nickName"
|
||||
]
|
||||
if user_id in self.userrealnames[abm.group_id]:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
try:
|
||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||
user_real_name = info["nickName"]
|
||||
except Exception as e:
|
||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||
user_real_name = user_id
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||
for wxid in at_wxids:
|
||||
# 群聊里 At 其他人的列表
|
||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||
abm.message.append(At(qq=wxid, name=_username))
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
|
||||
if user_id == "weixin":
|
||||
# 忽略微信团队消息
|
||||
return
|
||||
|
||||
# 不同消息类型
|
||||
match d["MsgType"]:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid, content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||
case 37: # 好友申请
|
||||
logger.info("消息类型(37):好友申请")
|
||||
case 42: # 名片
|
||||
logger.info("消息类型(42):名片")
|
||||
case 43: # 视频
|
||||
video = Video(file="", cover=content)
|
||||
abm.message.append(video)
|
||||
case 47: # emoji
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
emoji = data_parser.parse_emoji()
|
||||
abm.message.append(emoji)
|
||||
case 48: # 地理位置
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
segments = data_parser.parse_mutil_49()
|
||||
if segments:
|
||||
abm.message.extend(segments)
|
||||
for seg in segments:
|
||||
if isinstance(seg, Plain):
|
||||
abm.message_str += seg.text
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||
logger.info(
|
||||
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||
)
|
||||
|
||||
case _:
|
||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||
abm.raw_message = d
|
||||
|
||||
logger.debug(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def _callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get("testMsg", None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = None
|
||||
try:
|
||||
abm = await self._convert(data)
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。"
|
||||
)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _register_file(self, file_path: str) -> str:
|
||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||
|
||||
Args:
|
||||
file_path (str): 文件路径。
|
||||
Returns:
|
||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||
"""
|
||||
async with self.lock:
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
self.staged_files[file_token] = file_path
|
||||
return file_token
|
||||
|
||||
async def _handle_file(self, file_token):
|
||||
async with self.lock:
|
||||
if file_token not in self.staged_files:
|
||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||
return quart.abort(404)
|
||||
if not os.path.exists(self.staged_files[file_token]):
|
||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||
return quart.abort(404)
|
||||
file_path = self.staged_files[file_token]
|
||||
self.staged_files.pop(file_token, None)
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={"token": self.token, "callbackUrl": self.callback_url},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob["ret"] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(
|
||||
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
|
||||
)
|
||||
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host="0.0.0.0",
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
"""检查 APPID 对应的设备是否在线。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={"appId": appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def logout(self):
|
||||
"""登出 gewechat。"""
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={"appId": self.appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(
|
||||
self.base_url, self.download_base_url, self.token
|
||||
)
|
||||
|
||||
if self.appid:
|
||||
try:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态失败: {e}")
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob["ret"] != 200:
|
||||
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||
if "设备不存在" in error_msg:
|
||||
logger.error(
|
||||
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||
)
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
return await self.login()
|
||||
else:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob["data"]["qrData"]
|
||||
qr_uuid = json_blob["data"]["uuid"]
|
||||
appid = json_blob["data"]["appId"]
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(
|
||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({"uuid": qr_uuid, "appId": appid})
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
|
||||
ret = json_blob["ret"]
|
||||
msg = ""
|
||||
if json_blob["data"] and "msg" in json_blob["data"]:
|
||||
msg = json_blob["data"]["msg"]
|
||||
if ret == 500 and "安全验证码" in msg:
|
||||
logger.warning(
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
if "status" in json_blob["data"]:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||
"""
|
||||
|
||||
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||
"""获取群成员列表。
|
||||
|
||||
Args:
|
||||
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||
|
||||
Returns:
|
||||
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||
"""
|
||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
"""发送纯文本消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
if ats:
|
||||
payload["ats"] = ats
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
"""发送图片消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||
"""发送emoji消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"emojiMd5": emoji_md5,
|
||||
"emojiSize": emoji_size,
|
||||
}
|
||||
|
||||
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||
try:
|
||||
if emoji_md5 != "" and emoji_size != "":
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postEmoji",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(
|
||||
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||
)
|
||||
else:
|
||||
await self.post_image(to_wxid, cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
async def post_video(
|
||||
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"videoUrl": video_url,
|
||||
"thumbUrl": thumb_url,
|
||||
"videoDuration": video_duration,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送视频结果: {json_blob}")
|
||||
|
||||
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||
"""转发视频
|
||||
|
||||
Args:
|
||||
to_wxid (str): 发送给谁
|
||||
cnd_xml (str): 视频消息的cdn信息
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"xml": cnd_xml,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/forwardVideo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"转发视频结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
"""发送语音信息
|
||||
|
||||
Args:
|
||||
voice_url (str): 语音文件的网络链接
|
||||
voice_duration (int): 语音时长,毫秒
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration,
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||
|
||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||
"""发送文件
|
||||
|
||||
Args:
|
||||
to_wxid (string): 微信ID
|
||||
file_url (str): 文件的网络链接
|
||||
file_name (str): 文件名
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"fileUrl": file_url,
|
||||
"fileName": file_name,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送文件结果: {json_blob}")
|
||||
|
||||
async def add_friend(self, v3: str, v4: str, content: str):
|
||||
"""申请添加好友"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"scene": 3,
|
||||
"content": content,
|
||||
"v4": v4,
|
||||
"v3": v3,
|
||||
"option": 2,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/addContacts",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group_member(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def accept_group_invite(self, url: str):
|
||||
"""同意进群"""
|
||||
payload = {"appId": self.appid, "url": url}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/agreeJoinRoom",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def add_group_member_to_friend(
|
||||
self, group_id: str, to_wxid: str, content: str
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
"content": content,
|
||||
"memberWxid": to_wxid,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_user_or_group_info(self, *ids):
|
||||
"""
|
||||
获取用户或群组信息。
|
||||
|
||||
:param ids: 可变数量的 wxid 参数
|
||||
"""
|
||||
|
||||
wxids_str = list(ids)
|
||||
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/getDetailInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_contacts_list(self):
|
||||
"""
|
||||
获取通讯录列表
|
||||
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||
"""
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/fetchContactsList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||
return json_blob
|
||||
@@ -1,55 +0,0 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
class GeweDownloader:
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
"""返回一个可下载的 URL"""
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {"appId": appid, "xml": xml, "type": choice}
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadImage", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
if "fileUrl" in json_blob["data"]:
|
||||
return self.download_base_url + json_blob["data"]["fileUrl"]
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
|
||||
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||
"""下载emoji"""
|
||||
try:
|
||||
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||
|
||||
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadEmojiMd5", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
return json_blob
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download emoji: {e}")
|
||||
@@ -1,264 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
Record,
|
||||
At,
|
||||
File,
|
||||
Video,
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, "rb") as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
elif n_frames == 0:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(
|
||||
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
||||
):
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
# 检查@
|
||||
ats = []
|
||||
ats_names = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, At):
|
||||
ats.append(comp.qq)
|
||||
ats_names.append(comp.name)
|
||||
has_at = False
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text = comp.text
|
||||
payload = {
|
||||
"to_wxid": to_wxid,
|
||||
"content": text,
|
||||
}
|
||||
if not has_at and ats:
|
||||
ats = f"{','.join(ats)}"
|
||||
ats_names = f"@{' @'.join(ats_names)}"
|
||||
text = f"{ats_names} {text}"
|
||||
payload["content"] = text
|
||||
payload["ats"] = ats
|
||||
has_at = True
|
||||
await client.post_text(**payload)
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||
token = await client._register_file(img_path)
|
||||
img_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
if comp.cover != "":
|
||||
await client.forward_video(to_wxid, comp.cover)
|
||||
else:
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
raise ModuleNotFoundError(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
|
||||
video_token = await client._register_file(video_path)
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||
ff.options(command)
|
||||
thumb_token = await client._register_file(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
|
||||
# 创建 FFprobe 实例
|
||||
ffprobe = FFprobe(video_url)
|
||||
# 获取时长字符串
|
||||
duration_str = ffprobe.duration
|
||||
# 处理时长字符串
|
||||
video_duration = float(duration_str.replace(":", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
# 发送视频
|
||||
await client.post_video(
|
||||
to_wxid, video_callback_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时缩略图文件
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
token = await client._register_file(silk_path)
|
||||
record_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
file_path = comp.file
|
||||
file_name = comp.name
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
token = await client._register_file(file_path)
|
||||
file_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_name)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||
await super().send(message)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
# 确定有效的 group_id
|
||||
if group_id is None:
|
||||
group_id = self.get_group_id()
|
||||
|
||||
if not group_id:
|
||||
return None
|
||||
|
||||
res = await self.client.get_group(group_id)
|
||||
data: dict = res["data"]
|
||||
|
||||
if not data["chatroomId"]:
|
||||
return None
|
||||
|
||||
members = [
|
||||
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||
for member in data.get("memberList", [])
|
||||
]
|
||||
|
||||
return Group(
|
||||
group_id=data["chatroomId"],
|
||||
group_name=data.get("nickName"),
|
||||
group_avatar=data.get("smallHeadImgUrl"),
|
||||
group_owner=data.get("chatRoomOwner"),
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
@@ -1,103 +0,0 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||
self.client = None
|
||||
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config["base_url"],
|
||||
self.config["nickname"],
|
||||
self.config["host"],
|
||||
self.config["port"],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
session_id = session.session_id
|
||||
if "#" in session_id:
|
||||
# unique session
|
||||
to_wxid = session_id.split("#")[1]
|
||||
else:
|
||||
to_wxid = session_id
|
||||
|
||||
await GewechatPlatformEvent.send_with_client(
|
||||
message_chain, to_wxid, self.client
|
||||
)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="gewechat",
|
||||
description="基于 gewechat 的 Wechat 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
self.client.shutdown_event.set()
|
||||
try:
|
||||
await self.client.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
await self.client.start_polling()
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss["unique_session"]:
|
||||
message.session_id = message.sender.user_id + "#" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self) -> SimpleGewechatClient:
|
||||
return self.client
|
||||
@@ -1,110 +0,0 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Reply,
|
||||
Plain,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
def __init__(self, data, is_private_chat):
|
||||
self.data = data
|
||||
self.is_private_chat = is_private_chat
|
||||
|
||||
def _format_to_xml(self):
|
||||
return eT.fromstring(self.data)
|
||||
|
||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||
if appmsg_type is None:
|
||||
return
|
||||
|
||||
match appmsg_type.text:
|
||||
case "57":
|
||||
return self.parse_reply()
|
||||
|
||||
def parse_emoji(self) -> Emoji | None:
|
||||
try:
|
||||
emoji_element = self._format_to_xml().find(".//emoji")
|
||||
# 提取 md5 和 len 属性
|
||||
if emoji_element is not None:
|
||||
md5_value = emoji_element.get("md5")
|
||||
emoji_size = emoji_element.get("len")
|
||||
cdnurl = emoji_element.get("cdnurl")
|
||||
|
||||
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||
|
||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||
"""解析引用消息
|
||||
|
||||
Returns:
|
||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||
"""
|
||||
try:
|
||||
replied_id = -1
|
||||
replied_uid = 0
|
||||
replied_nickname = ""
|
||||
replied_content = "" # 被引用者说的内容
|
||||
content = "" # 引用者说的内容
|
||||
|
||||
root = self._format_to_xml()
|
||||
refermsg = root.find(".//refermsg")
|
||||
if refermsg is not None:
|
||||
# 被引用的信息
|
||||
svrid = refermsg.find("svrid")
|
||||
fromusr = refermsg.find("fromusr")
|
||||
displayname = refermsg.find("displayname")
|
||||
refermsg_content = refermsg.find("content")
|
||||
if svrid is not None:
|
||||
replied_id = svrid.text
|
||||
if fromusr is not None:
|
||||
replied_uid = fromusr.text
|
||||
if displayname is not None:
|
||||
replied_nickname = displayname.text
|
||||
if refermsg_content is not None:
|
||||
# 处理引用嵌套,包括嵌套公众号消息
|
||||
if refermsg_content.text.startswith(
|
||||
"<msg>"
|
||||
) or refermsg_content.text.startswith("<?xml"):
|
||||
try:
|
||||
logger.debug("gewechat: Reference message is nested")
|
||||
refer_root = eT.fromstring(refermsg_content.text)
|
||||
img = refer_root.find("img")
|
||||
if img is not None:
|
||||
replied_content = "[图片]"
|
||||
else:
|
||||
app_msg = refer_root.find("appmsg")
|
||||
refermsg_content_title = app_msg.find("title")
|
||||
logger.debug(
|
||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||
)
|
||||
replied_content = refermsg_content_title.text
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: nested failed, {e}")
|
||||
# 处理异常情况
|
||||
replied_content = refermsg_content.text
|
||||
else:
|
||||
replied_content = refermsg_content.text
|
||||
|
||||
# 提取引用者说的内容
|
||||
title = root.find(".//appmsg/title")
|
||||
if title is not None:
|
||||
content = title.text
|
||||
|
||||
reply_seg = Reply(
|
||||
id=replied_id,
|
||||
chain=[Plain(replied_content)],
|
||||
sender_id=replied_uid,
|
||||
sender_nickname=replied_nickname,
|
||||
message_str=replied_content,
|
||||
)
|
||||
plain_seg = Plain(content)
|
||||
return [reply_seg, plain_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||
@@ -3,15 +3,23 @@ import botpy.message
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
import asyncio
|
||||
import base64
|
||||
import aiofiles
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
from botpy.types.message import Media
|
||||
from botpy.types import message
|
||||
from typing import Optional
|
||||
import random
|
||||
import uuid
|
||||
import os
|
||||
|
||||
|
||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
@@ -36,6 +44,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||
ret = None
|
||||
try:
|
||||
async for chain in generator:
|
||||
source = self.message_obj.raw_message
|
||||
@@ -85,9 +94,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
plain_text,
|
||||
image_base64,
|
||||
image_path,
|
||||
record_file_path
|
||||
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||
|
||||
if not plain_text and not image_base64 and not image_path:
|
||||
if not plain_text and not image_base64 and not image_path and not record_file_path:
|
||||
return
|
||||
|
||||
payload = {
|
||||
@@ -98,6 +108,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
|
||||
payload["msg_seq"] = random.randint(1, 10000)
|
||||
|
||||
ret = None
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
if image_base64:
|
||||
@@ -106,6 +118,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
if record_file_path: # group record msg
|
||||
media = await self.upload_group_and_c2c_record(
|
||||
record_file_path, 3, group_openid=source.group_openid
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
ret = await self.bot.api.post_group_message(
|
||||
group_openid=source.group_openid, **payload
|
||||
)
|
||||
@@ -116,6 +134,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
if record_file_path: # c2c record
|
||||
media = await self.upload_group_and_c2c_record(
|
||||
record_file_path, 3, openid = source.author.user_openid
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
if stream:
|
||||
ret = await self.post_c2c_message(
|
||||
openid=source.author.user_openid,
|
||||
@@ -165,6 +189,59 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
async def upload_group_and_c2c_record(
|
||||
self,
|
||||
file_source: str,
|
||||
file_type: int,
|
||||
srv_send_msg: bool = False,
|
||||
**kwargs
|
||||
) -> Optional[Media]:
|
||||
"""
|
||||
上传媒体文件
|
||||
"""
|
||||
# 构建基础payload
|
||||
payload = {
|
||||
"file_type": file_type,
|
||||
"srv_send_msg": srv_send_msg
|
||||
}
|
||||
|
||||
# 处理文件数据
|
||||
if os.path.exists(file_source):
|
||||
# 读取本地文件
|
||||
async with aiofiles.open(file_source, 'rb') as f:
|
||||
file_content = await f.read()
|
||||
# use base64 encode
|
||||
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
|
||||
else:
|
||||
# 使用URL
|
||||
payload["url"] = file_source
|
||||
|
||||
# 添加接收者信息和确定路由
|
||||
if "openid" in kwargs:
|
||||
payload["openid"] = kwargs["openid"]
|
||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||
elif "group_openid" in kwargs:
|
||||
payload["group_openid"] =kwargs["group_openid"]
|
||||
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
|
||||
else:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 使用底层HTTP请求
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if result:
|
||||
return Media(
|
||||
file_uuid=result.get("file_uuid"),
|
||||
file_info=result.get("file_info"),
|
||||
ttl=result.get("ttl", 0),
|
||||
file_id=result.get("id", "")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传请求错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def post_c2c_message(
|
||||
self,
|
||||
openid: str,
|
||||
@@ -191,6 +268,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
plain_text = ""
|
||||
image_base64 = None # only one img supported
|
||||
image_file_path = None
|
||||
record_file_path = None
|
||||
for i in message.chain:
|
||||
if isinstance(i, Plain):
|
||||
plain_text += i.text
|
||||
@@ -206,6 +284,21 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
else:
|
||||
image_base64 = file_to_base64(i.file)
|
||||
image_base64 = image_base64.removeprefix("base64://")
|
||||
elif isinstance(i, Record):
|
||||
if i.file:
|
||||
record_wav_path = await i.convert_to_file_path() # wav 路径
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
|
||||
if duration > 0:
|
||||
record_file_path = record_tecent_silk_path
|
||||
else:
|
||||
record_file_path = None
|
||||
logger.error("转换音频格式时出错:音频时长不大于0")
|
||||
except Exception as e:
|
||||
logger.error(f"处理语音时出错: {e}")
|
||||
record_file_path = None
|
||||
else:
|
||||
logger.debug(f"qq_official 忽略 {i.type}")
|
||||
return plain_text, image_base64, image_file_path
|
||||
return plain_text, image_base64, image_file_path, record_file_path
|
||||
|
||||
@@ -77,7 +77,7 @@ class WebChatAdapter(Platform):
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id", "")
|
||||
name="webchat", description="webchat", id="webchat"
|
||||
)
|
||||
|
||||
async def send_by_session(
|
||||
|
||||
@@ -22,7 +22,11 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
if not message:
|
||||
await web_chat_back_queue.put(
|
||||
{"type": "end", "data": "", "streaming": False}
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
} # end means this request is finished
|
||||
)
|
||||
return ""
|
||||
|
||||
@@ -99,16 +103,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
"cid": cid,
|
||||
}
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
@@ -120,7 +114,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
@@ -134,7 +128,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"type": "complete", # complete means we return the final result
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
|
||||
47
astrbot/core/platform_message_history_mgr.py
Normal file
47
astrbot/core/platform_message_history_mgr.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import PlatformMessageHistory
|
||||
|
||||
|
||||
class PlatformMessageHistoryManager:
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
self.db = db_helper
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict], # TODO: parse from message chain
|
||||
sender_id: str = None,
|
||||
sender_name: str = None,
|
||||
):
|
||||
"""Insert a new platform message history record."""
|
||||
await self.db.insert_platform_message_history(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
sender_id=sender_id,
|
||||
sender_name=sender_name,
|
||||
)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 200,
|
||||
) -> list[PlatformMessageHistory]:
|
||||
"""Get platform message history for a specific user."""
|
||||
history = await self.db.get_platform_message_history(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
history.reverse()
|
||||
return history
|
||||
|
||||
async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400):
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
await self.db.delete_platform_message_offset(
|
||||
platform_id=platform_id, user_id=user_id, offset_sec=offset_sec
|
||||
)
|
||||
@@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot import logger
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
@@ -20,6 +20,7 @@ class ProviderType(enum.Enum):
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,7 +98,7 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall | None = None
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
@@ -293,3 +294,10 @@ class LLMResponse:
|
||||
}
|
||||
)
|
||||
return ret
|
||||
|
||||
@dataclass
|
||||
class RerankResult:
|
||||
index: int
|
||||
"""在候选列表中的索引位置"""
|
||||
relevance_score: float
|
||||
"""相关性分数"""
|
||||
|
||||
@@ -1,32 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
import aiohttp
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, List, Awaitable
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
from astrbot.core import sp
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.agent.mcp_client import MCPClient
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
@@ -39,175 +24,109 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
用于描述一个函数调用工具。
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
handler: Awaitable = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
async def execute(self, **args) -> Any:
|
||||
"""执行函数调用"""
|
||||
if self.origin == "local":
|
||||
if not self.handler:
|
||||
raise Exception(f"Local function {self.name} has no handler")
|
||||
return await self.handler(**args)
|
||||
elif self.origin == "mcp":
|
||||
if not self.mcp_client or not self.mcp_client.session:
|
||||
raise Exception(f"MCP client for {self.name} is not available")
|
||||
# 使用name属性而不是额外的mcp_tool_name
|
||||
if ":" in self.name:
|
||||
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||
actual_tool_name = self.name.split(":")[-1]
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
return await self.mcp_client.session.call_tool(self.name, args)
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
# alias
|
||||
FuncTool = FunctionTool
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = mcp_server_config.copy()
|
||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||
cfg = cfg["mcpServers"][key_0]
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
if "url" in cfg:
|
||||
is_sse = True
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
class FuncCall:
|
||||
class FunctionToolManager:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: List[FuncTool] = []
|
||||
"""内部加载的 func tools"""
|
||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_service_queue = asyncio.Queue()
|
||||
"""用于外部控制 MCP 服务的启停"""
|
||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
|
||||
def spec_to_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Awaitable,
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
return FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
@@ -225,22 +144,14 @@ class FuncCall:
|
||||
# check if the tool has been added before
|
||||
self.remove_func(name)
|
||||
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
self.func_list.append(
|
||||
self.spec_to_func(
|
||||
name=name,
|
||||
func_args=func_args,
|
||||
desc=desc,
|
||||
handler=handler,
|
||||
)
|
||||
)
|
||||
self.func_list.append(_func)
|
||||
logger.info(f"添加函数调用工具: {name}")
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
@@ -252,13 +163,17 @@ class FuncCall:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def get_func(self, name) -> FuncTool:
|
||||
def get_func(self, name) -> FuncTool | None:
|
||||
for f in self.func_list:
|
||||
if f.name == name:
|
||||
return f
|
||||
return None
|
||||
|
||||
async def _init_mcp_clients(self) -> None:
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""获取完整工具集"""
|
||||
tool_set = ToolSet(self.func_list.copy())
|
||||
return tool_set
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
```
|
||||
{
|
||||
@@ -300,113 +215,64 @@ class FuncCall:
|
||||
)
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
async def mcp_service_selector(self):
|
||||
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||
|
||||
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||
|
||||
{"type": "init"} 初始化所有MCP客户端
|
||||
|
||||
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||
|
||||
{"type": "terminate"} 终止所有MCP客户端
|
||||
|
||||
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||
"""
|
||||
while True:
|
||||
data = await self.mcp_service_queue.get()
|
||||
if data["type"] == "init":
|
||||
if "name" in data:
|
||||
event = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(
|
||||
data["name"], data["cfg"], event
|
||||
)
|
||||
)
|
||||
self.mcp_client_event[data["name"]] = event
|
||||
else:
|
||||
await self._init_mcp_clients()
|
||||
elif data["type"] == "terminate":
|
||||
if "name" in data:
|
||||
# await self._terminate_mcp_client(data["name"])
|
||||
if data["name"] in self.mcp_client_event:
|
||||
self.mcp_client_event[data["name"]].set()
|
||||
self.mcp_client_event.pop(data["name"], None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (
|
||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
for name in self.mcp_client_dict.keys():
|
||||
# await self._terminate_mcp_client(name)
|
||||
# self.mcp_client_event[name].set()
|
||||
if name in self.mcp_client_event:
|
||||
self.mcp_client_event[name].set()
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self, name: str, cfg: dict, event: asyncio.Event
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
await self._terminate_mcp_client(name)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
finally:
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
"""初始化单个MCP客户端"""
|
||||
try:
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
# 先清理之前的客户端,如果存在
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
logger.debug(f"MCP server {name} list tools response: {tools_res}")
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||
]
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||
]
|
||||
|
||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||
for tool in mcp_client.tools:
|
||||
func_tool = FuncTool(
|
||||
name=tool.name,
|
||||
parameters=tool.inputSchema,
|
||||
description=tool.description,
|
||||
origin="mcp",
|
||||
mcp_server_name=name,
|
||||
mcp_client=mcp_client,
|
||||
)
|
||||
self.func_list.append(func_tool)
|
||||
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||
for tool in mcp_client.tools:
|
||||
func_tool = FuncTool(
|
||||
name=tool.name,
|
||||
parameters=tool.inputSchema,
|
||||
description=tool.description,
|
||||
origin="mcp",
|
||||
mcp_server_name=name,
|
||||
mcp_client=mcp_client,
|
||||
)
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -414,9 +280,9 @@ class FuncCall:
|
||||
try:
|
||||
# 关闭MCP连接
|
||||
await self.mcp_client_dict[name].cleanup()
|
||||
del self.mcp_client_dict[name]
|
||||
self.mcp_client_dict.pop(name)
|
||||
except Exception as e:
|
||||
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
@@ -425,204 +291,273 @@ class FuncCall:
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
if "url" in config:
|
||||
success, error_msg = await _quick_test_mcp_connection(config)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
try:
|
||||
logger.debug(f"testing MCP server connection with config: {config}")
|
||||
await mcp_client.connect_to_server(config, "test")
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
finally:
|
||||
logger.debug("Cleaning up MCP client after testing connection.")
|
||||
await mcp_client.cleanup()
|
||||
return tool_names
|
||||
|
||||
async def enable_mcp_server(
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
"""Enable_mcp_server a new MCP server to the manager and initialize it.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
Raises:
|
||||
TimeoutError: If the initialization does not complete within the specified timeout.
|
||||
Exception: If there is an error during initialization.
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str | None = None, timeout: float = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server by its name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
||||
timeout (int): Timeout.
|
||||
"""
|
||||
if name:
|
||||
if name not in self.mcp_client_event:
|
||||
return
|
||||
client = self.mcp_client_dict.get(name)
|
||||
self.mcp_client_event[name].set()
|
||||
if not client:
|
||||
return
|
||||
client_running_event = client.running_event
|
||||
try:
|
||||
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if f.origin != "mcp" or f.mcp_server_name != name
|
||||
]
|
||||
else:
|
||||
running_events = [
|
||||
client.running_event.wait() for client in self.mcp_client_dict.values()
|
||||
]
|
||||
for key, event in self.mcp_client_event.items():
|
||||
event.set()
|
||||
# waiting for all clients to finish
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
|
||||
finally:
|
||||
self.mcp_client_event.clear()
|
||||
self.mcp_client_dict.clear()
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
_l = []
|
||||
# 处理所有工具(包括本地和MCP工具)
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
func_ = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
# "parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
func_["function"]["parameters"] = f.parameters
|
||||
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||
del func_["function"]["parameters"]
|
||||
_l.append(func_)
|
||||
return _l
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.openai_schema(
|
||||
omit_empty_parameter_field=omit_empty_parameter_field
|
||||
)
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
"""
|
||||
获得 Anthropic API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
# Convert internal format to Anthropic style
|
||||
tool = {
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": f.parameters.get("properties", {}),
|
||||
# Keep the required field from the original parameters if it exists
|
||||
"required": f.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
tools.append(tool)
|
||||
return tools
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.anthropic_schema()
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> dict:
|
||||
"""
|
||||
获得 Google GenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.google_schema()
|
||||
|
||||
# Gemini API 支持的数据类型和格式
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""停用一个已经注册的函数调用工具。
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""转换 schema 为 Gemini API 格式"""
|
||||
Returns:
|
||||
如果没找到,会返回 False"""
|
||||
func_tool = self.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
# 暂时指定默认为null
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties: # 只在有非空属性时添加
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
**({"parameters": convert_schema(f.parameters)}),
|
||||
}
|
||||
for f in self.func_list
|
||||
if f.active
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
inactivated_llm_tools: list = sp.get(
|
||||
"inactivated_llm_tools", [], scope="global", scope_id="global"
|
||||
)
|
||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
inactivated_llm_tools,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
|
||||
return True
|
||||
return False
|
||||
|
||||
TOOLS:
|
||||
可用的函数列表:
|
||||
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
|
||||
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
|
||||
func_tool = self.get_func(name)
|
||||
if func_tool is not None:
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(
|
||||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
||||
)
|
||||
|
||||
{func_definition}
|
||||
func_tool.active = True
|
||||
|
||||
LIMIT:
|
||||
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
|
||||
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
|
||||
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
|
||||
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
|
||||
inactivated_llm_tools: list = sp.get(
|
||||
"inactivated_llm_tools", [], scope="global", scope_id="global"
|
||||
)
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
inactivated_llm_tools,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
EXAMPLE:
|
||||
1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
|
||||
return True
|
||||
return False
|
||||
|
||||
用户的提问是:{question}
|
||||
""")
|
||||
@property
|
||||
def mcp_config_path(self):
|
||||
data_dir = get_astrbot_data_path()
|
||||
return os.path.join(data_dir, "mcp_server.json")
|
||||
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
if res.find("```") != -1:
|
||||
res = res[res.find("```json") + 7 : res.rfind("```")]
|
||||
res = json.loads(res)
|
||||
break
|
||||
except Exception as e:
|
||||
_c += 1
|
||||
if _c == 3:
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise e
|
||||
def load_mcp_config(self):
|
||||
if not os.path.exists(self.mcp_config_path):
|
||||
# 配置文件不存在,创建默认配置
|
||||
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
|
||||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
if "res" in res and not res["res"]:
|
||||
return "", False
|
||||
try:
|
||||
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 MCP 配置失败: {e}")
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
tool_call_result = []
|
||||
for tool in res:
|
||||
# 说明有函数调用
|
||||
func_name = tool["name"]
|
||||
args = tool["args"]
|
||||
# 调用函数
|
||||
func_tool = self.get_func(func_name)
|
||||
if not func_tool:
|
||||
raise Exception(f"Request function {func_name} not found.")
|
||||
def save_mcp_config(self, config: dict):
|
||||
try:
|
||||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"保存 MCP 配置失败: {e}")
|
||||
return False
|
||||
|
||||
ret = await func_tool.execute(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
|
||||
"""从 ModelScope 平台同步 MCP 服务器配置"""
|
||||
base_url = "https://www.modelscope.cn/openapi/v1"
|
||||
url = f"{base_url}/mcp/servers/operational"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token.strip()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
mcp_server_list = data.get("data", {}).get(
|
||||
"mcp_server_list", []
|
||||
)
|
||||
local_mcp_config = self.load_mcp_config()
|
||||
|
||||
synced_count = 0
|
||||
for server in mcp_server_list:
|
||||
server_name = server["name"]
|
||||
operational_urls = server.get("operational_urls", [])
|
||||
if not operational_urls:
|
||||
continue
|
||||
url_info = operational_urls[0]
|
||||
server_url = url_info.get("url")
|
||||
if not server_url:
|
||||
continue
|
||||
# 添加到配置中(同名会覆盖)
|
||||
local_mcp_config["mcpServers"][server_name] = {
|
||||
"url": server_url,
|
||||
"transport": "sse",
|
||||
"active": True,
|
||||
"provider": "modelscope",
|
||||
}
|
||||
synced_count += 1
|
||||
|
||||
if synced_count > 0:
|
||||
self.save_mcp_config(local_mcp_config)
|
||||
tasks = []
|
||||
for server in mcp_server_list:
|
||||
name = server["name"]
|
||||
tasks.append(
|
||||
self.enable_mcp_server(
|
||||
name=name,
|
||||
config=local_mcp_config["mcpServers"][name],
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(
|
||||
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器"
|
||||
)
|
||||
else:
|
||||
logger.warning("没有找到可用的 ModelScope MCP 服务器")
|
||||
else:
|
||||
raise Exception(
|
||||
f"ModelScope API 请求失败: HTTP {response.status}"
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"网络连接错误: {str(e)}")
|
||||
except Exception as e:
|
||||
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {str(e)}")
|
||||
|
||||
def __str__(self):
|
||||
return str(self.func_list)
|
||||
@@ -630,7 +565,6 @@ class FuncCall:
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
async def terminate(self):
|
||||
for name in self.mcp_client_dict.keys():
|
||||
await self._terminate_mcp_client(name)
|
||||
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
||||
|
||||
# alias
|
||||
FuncCall = FunctionToolManager
|
||||
|
||||
@@ -3,89 +3,32 @@ import traceback
|
||||
from typing import List
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider
|
||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
from ..persona_mgr import PersonaManager
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
def __init__(
|
||||
self,
|
||||
acm: AstrBotConfigManager,
|
||||
db_helper: BaseDatabase,
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: List = config["provider"]
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
for persona in self.persona_configs:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
bd_processed = []
|
||||
mid_processed = ""
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
mood_imitation_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in mood_imitation_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
mid_processed += "\n"
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed=mid_processed,
|
||||
)
|
||||
if persona["name"] == self.default_persona_name:
|
||||
self.selected_default_persona = persona
|
||||
self.personas.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
if not self.selected_default_persona:
|
||||
self.selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed="",
|
||||
)
|
||||
self.personas.append(self.selected_default_persona)
|
||||
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
||||
self.default_persona_name = persona_mgr.default_persona
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
"""加载的 Provider 的实例"""
|
||||
@@ -93,53 +36,118 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map: dict[str, Provider] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
self.curr_provider_inst: Provider | None = None
|
||||
"""默认的 Provider 实例"""
|
||||
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.curr_stt_provider_inst: STTProvider | None = None
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.curr_tts_provider_inst: TTSProvider | None = None
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
self.curr_kdb_name = ""
|
||||
kdb_cfg = config.get("knowledge_db", {})
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
@property
|
||||
def persona_configs(self) -> list:
|
||||
"""动态获取最新的 persona 配置"""
|
||||
return self.persona_mgr.persona_v3_config
|
||||
|
||||
@property
|
||||
def personas(self) -> list:
|
||||
"""动态获取最新的 personas 列表"""
|
||||
return self.persona_mgr.personas_v3
|
||||
|
||||
@property
|
||||
def selected_default_persona(self):
|
||||
"""动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()"""
|
||||
return self.persona_mgr.selected_default_persona_v3
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str | None = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||||
|
||||
Version 4.0.0: 这个版本下已经默认隔离提供商
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo and self.provider_settings["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
if umo:
|
||||
await sp.session_put(
|
||||
umo,
|
||||
f"provider_perf_{provider_type.value}",
|
||||
provider_id,
|
||||
)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||||
|
||||
Returns:
|
||||
Provider: 正在使用的提供商实例。
|
||||
"""
|
||||
provider = None
|
||||
if umo:
|
||||
provider_id = sp.get(
|
||||
f"provider_perf_{provider_type.value}",
|
||||
None,
|
||||
scope="umo",
|
||||
scope_id=umo,
|
||||
)
|
||||
if provider_id:
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
# default setting
|
||||
config = self.acm.get_conf(umo)
|
||||
if provider_type == ProviderType.CHAT_COMPLETION:
|
||||
provider_id = config["provider_settings"].get("default_provider_id")
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = self.provider_insts[0] if self.provider_insts else None
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
provider_id = config["provider_stt_settings"].get("provider_id")
|
||||
if not provider_id:
|
||||
return None
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = (
|
||||
self.stt_provider_insts[0] if self.stt_provider_insts else None
|
||||
)
|
||||
elif provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
provider_id = config["provider_tts_settings"].get("provider_id")
|
||||
if not provider_id:
|
||||
return None
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = (
|
||||
self.tts_provider_insts[0] if self.tts_provider_insts else None
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
return provider
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
@@ -148,13 +156,22 @@ class ProviderManager:
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider", self.provider_settings.get("default_provider_id")
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_stt_provider_id = sp.get(
|
||||
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
|
||||
"curr_provider_stt",
|
||||
self.provider_stt_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_tts_provider_id = sp.get(
|
||||
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
|
||||
"curr_provider_tts",
|
||||
self.provider_tts_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
@@ -169,10 +186,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
||||
)
|
||||
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
if not provider_config["enable"]:
|
||||
@@ -265,6 +279,10 @@ class ProviderManager:
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -348,7 +366,7 @@ class ProviderManager:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
@@ -422,7 +440,7 @@ class ProviderManager:
|
||||
self.curr_tts_provider_inst = None
|
||||
|
||||
if getattr(self.inst_map[provider_id], "terminate", None):
|
||||
await self.inst_map[provider_id].terminate() # type: ignore
|
||||
await self.inst_map[provider_id].terminate() # type: ignore
|
||||
|
||||
logger.info(
|
||||
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
||||
@@ -432,6 +450,8 @@ class ProviderManager:
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
await provider_inst.terminate() # type: ignore
|
||||
# 清理 MCP Client 连接
|
||||
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
||||
await provider_inst.terminate() # type: ignore
|
||||
try:
|
||||
await self.llm_tools.disable_mcp_server()
|
||||
except Exception:
|
||||
logger.error("Error while disabling MCP servers", exc_info=True)
|
||||
|
||||
@@ -1,27 +1,24 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ToolCallsResult,
|
||||
ProviderType,
|
||||
RerankResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.db.po import Personality
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: List[str] = []
|
||||
mood_imitation_dialogs: List[str] = []
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: List[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta:
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
provider_type: ProviderType
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
@@ -40,10 +37,14 @@ class AbstractProvider(abc.ABC):
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
"""获取 Provider 的元数据"""
|
||||
provider_type_name = self.provider_config["type"]
|
||||
meta_data = provider_cls_map.get(provider_type_name)
|
||||
provider_type = meta_data.provider_type if meta_data else None
|
||||
return ProviderMeta(
|
||||
id=self.provider_config["id"],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config["type"],
|
||||
type=provider_type_name,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -84,7 +85,7 @@ class Provider(AbstractProvider):
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
func_tool: ToolSet = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
@@ -113,7 +114,7 @@ class Provider(AbstractProvider):
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
func_tool: ToolSet = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
@@ -200,3 +201,17 @@ class EmbeddingProvider(AbstractProvider):
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
|
||||
class RerankProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def rerank(
|
||||
self, query: str, documents: list[str], top_n: int | None = None
|
||||
) -> list[RerankResult]:
|
||||
"""获取查询和文档的重排序分数"""
|
||||
...
|
||||
|
||||
@@ -75,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
|
||||
@@ -97,9 +97,9 @@ class ProviderDify(Provider):
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import uuid
|
||||
import re
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
from httpx import AsyncClient
|
||||
@@ -24,8 +25,8 @@ class ServeTTSRequest(BaseModel):
|
||||
# 参考音频
|
||||
references: list[ServeReferenceAudio] = []
|
||||
# 参考模型 ID
|
||||
# 例如 https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
||||
# 其中reference_id为 7f92f8afb8ec43bf81429cc1c9199cb1
|
||||
# 例如 https://fish.audio/m/626bb6d3f3364c9cbc3aa6a67300a664/
|
||||
# 其中reference_id为 626bb6d3f3364c9cbc3aa6a67300a664
|
||||
reference_id: str | None = None
|
||||
# 对中英文文本进行标准化,这可以提高数字的稳定性
|
||||
normalize: bool = True
|
||||
@@ -44,6 +45,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.reference_id: str = provider_config.get("fishaudio-tts-reference-id", "")
|
||||
self.character: str = provider_config.get("fishaudio-tts-character", "可莉")
|
||||
self.api_base: str = provider_config.get(
|
||||
"api_base", "https://api.fish-audio.cn/v1"
|
||||
@@ -81,11 +83,43 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
return item["_id"]
|
||||
return None
|
||||
|
||||
def _validate_reference_id(self, reference_id: str) -> bool:
|
||||
"""
|
||||
验证reference_id格式是否有效
|
||||
|
||||
Args:
|
||||
reference_id: 参考模型ID
|
||||
|
||||
Returns:
|
||||
bool: ID是否有效
|
||||
"""
|
||||
if not reference_id or not reference_id.strip():
|
||||
return False
|
||||
|
||||
# FishAudio的reference_id通常是32位十六进制字符串
|
||||
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
|
||||
pattern = r'^[a-fA-F0-9]{32}$'
|
||||
return bool(re.match(pattern, reference_id.strip()))
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
||||
if self.reference_id and self.reference_id.strip():
|
||||
# 验证reference_id格式
|
||||
if not self._validate_reference_id(self.reference_id):
|
||||
raise ValueError(
|
||||
f"无效的FishAudio参考模型ID: '{self.reference_id}'. "
|
||||
f"请确保ID是32位十六进制字符串(例如: 626bb6d3f3364c9cbc3aa6a67300a664)。"
|
||||
f"您可以从 https://fish.audio/zh-CN/discovery 获取有效的模型ID。"
|
||||
)
|
||||
reference_id = self.reference_id.strip()
|
||||
else:
|
||||
# 回退到原来的角色名称查询逻辑
|
||||
reference_id = await self._get_reference_id_by_character(self.character)
|
||||
|
||||
return ServeTTSRequest(
|
||||
text=text,
|
||||
format="wav",
|
||||
reference_id=await self._get_reference_id_by_character(self.character),
|
||||
reference_id=reference_id,
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
|
||||
@@ -431,6 +431,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
continue
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.raw_completion = result
|
||||
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||
return llm_response
|
||||
|
||||
@@ -470,6 +471,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
raise
|
||||
continue
|
||||
|
||||
# Accumulate the complete response text for the final response
|
||||
accumulated_text = ""
|
||||
final_response = None
|
||||
|
||||
async for chunk in result:
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
@@ -477,27 +482,43 @@ class ProviderGoogleGenAI(Provider):
|
||||
part.function_call for part in chunk.candidates[0].content.parts
|
||||
):
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
llm_response.raw_completion = chunk
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
return
|
||||
|
||||
if chunk.text:
|
||||
accumulated_text += chunk.text
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
if not chunk.candidates[0].content.parts:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||
else:
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
# Process the final chunk for potential tool calls or other content
|
||||
if chunk.candidates[0].content.parts:
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
final_response.raw_completion = chunk
|
||||
final_response.result_chain = self._process_content_parts(
|
||||
chunk, final_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
# Yield final complete response with accumulated text
|
||||
if not final_response:
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
|
||||
# Set the complete accumulated text in the final response
|
||||
if accumulated_text:
|
||||
final_response.result_chain = MessageChain(
|
||||
chain=[Comp.Plain(accumulated_text)]
|
||||
)
|
||||
elif not final_response.result_chain:
|
||||
# If no text was accumulated and no final response was set, provide empty space
|
||||
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||
|
||||
yield final_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1024)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
|
||||
@@ -99,6 +99,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for key in to_del:
|
||||
del payloads[key]
|
||||
|
||||
model = payloads.get("model", "")
|
||||
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
|
||||
if "qwen3" in model.lower():
|
||||
extra_body["enable_thinking"] = False
|
||||
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
||||
elif model == "deepseek-reasoner" and "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
@@ -176,7 +184,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
if choice.message.content:
|
||||
if choice.message.content is not None:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
@@ -187,6 +195,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
func_name_ls = []
|
||||
tool_call_ids = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
# workaround for #1454
|
||||
@@ -207,7 +218,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
|
||||
)
|
||||
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
if llm_response.completion_text is None and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
@@ -482,13 +493,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"""
|
||||
new_contexts = []
|
||||
|
||||
flag = False
|
||||
for context in contexts:
|
||||
if flag:
|
||||
flag = False # 删除 image 后,下一条(LLM 响应)也要删除
|
||||
continue
|
||||
if "content" in context and isinstance(context["content"], list):
|
||||
flag = True
|
||||
# continue
|
||||
new_content = []
|
||||
for item in context["content"]:
|
||||
|
||||
59
astrbot/core/provider/sources/vllm_rerank_source.py
Normal file
59
astrbot/core/provider/sources/vllm_rerank_source.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import aiohttp
|
||||
from ..provider import RerankProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType, RerankResult
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"vllm_rerank",
|
||||
"VLLM Rerank 适配器",
|
||||
provider_type=ProviderType.RERANK,
|
||||
)
|
||||
class VLLMRerankProvider(RerankProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
self.auth_key = provider_config.get("rerank_api_key", "")
|
||||
self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
|
||||
self.base_url = self.base_url.rstrip("/")
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
self.model = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
|
||||
|
||||
h = {}
|
||||
if self.auth_key:
|
||||
h["Authorization"] = f"Bearer {self.auth_key}"
|
||||
self.client = aiohttp.ClientSession(
|
||||
headers=h,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
)
|
||||
|
||||
async def rerank(
|
||||
self, query: str, documents: list[str], top_n: int | None = None
|
||||
) -> list[RerankResult]:
|
||||
payload = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"model": self.model,
|
||||
}
|
||||
if top_n is not None:
|
||||
payload["top_n"] = top_n
|
||||
async with self.client.post(
|
||||
f"{self.base_url}/v1/rerank", json=payload
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
results = response_data.get("results", [])
|
||||
|
||||
return [
|
||||
RerankResult(
|
||||
index=result["index"],
|
||||
relevance_score=result["relevance_score"],
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""关闭客户端会话"""
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
@@ -1,4 +1,4 @@
|
||||
from .star import StarMetadata
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
from .star_manager import PluginManager
|
||||
from .context import Context
|
||||
from astrbot.core.provider import Provider
|
||||
@@ -10,25 +10,48 @@ from astrbot.core.star.star_tools import StarTools
|
||||
class Star(CommandParserMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
def __init__(self, context: Context, config: dict | None = None):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not star_map.get(cls.__module__):
|
||||
metadata = StarMetadata(
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__,
|
||||
)
|
||||
star_map[cls.__module__] = metadata
|
||||
star_registry.append(metadata)
|
||||
else:
|
||||
star_map[cls.__module__].star_cls_type = cls
|
||||
star_map[cls.__module__].module_path = cls.__module__
|
||||
|
||||
@staticmethod
|
||||
async def text_to_image(text: str, return_url=True) -> str:
|
||||
"""将文本转换为图片"""
|
||||
return await html_renderer.render_t2i(text, return_url=return_url)
|
||||
|
||||
@staticmethod
|
||||
async def html_render(
|
||||
self, tmpl: str, data: dict, return_url=True, options: dict = None
|
||||
tmpl: str, data: dict, return_url=True, options: dict | None = None
|
||||
) -> str:
|
||||
"""渲染 HTML"""
|
||||
return await html_renderer.render_custom_template(
|
||||
tmpl, data, return_url=return_url, options=options
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""当插件被激活时会调用这个方法"""
|
||||
pass
|
||||
|
||||
async def terminate(self):
|
||||
"""当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.provider import (
|
||||
Provider,
|
||||
TTSProvider,
|
||||
STTProvider,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
@@ -22,6 +29,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
ADAPTER_NAME_2_TYPE,
|
||||
)
|
||||
from deprecated import deprecated
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -29,19 +37,6 @@ class Context:
|
||||
暴露给插件的接口上下文。
|
||||
"""
|
||||
|
||||
_event_queue: Queue = None
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
|
||||
_config: AstrBotConfig = None
|
||||
"""AstrBot 配置信息"""
|
||||
|
||||
_db: BaseDatabase = None
|
||||
"""AstrBot 数据库"""
|
||||
|
||||
provider_manager: ProviderManager = None
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
@@ -53,18 +48,27 @@ class Context:
|
||||
event_queue: Queue,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
provider_manager: ProviderManager,
|
||||
platform_manager: PlatformManager,
|
||||
conversation_manager: ConversationManager,
|
||||
message_history_manager: PlatformMessageHistoryManager,
|
||||
persona_manager: PersonaManager,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
self._config = config
|
||||
"""AstrBot 默认配置"""
|
||||
self._db = db
|
||||
"""AstrBot 数据库"""
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
self.message_history_manager = message_history_manager
|
||||
self.persona_manager = persona_manager
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
||||
"""根据插件名获取插件的 Metadata"""
|
||||
for star in star_registry:
|
||||
if star.name == star_name:
|
||||
@@ -74,7 +78,7 @@ class Context:
|
||||
"""获取当前载入的所有插件 Metadata 的列表"""
|
||||
return star_registry
|
||||
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
def get_llm_tool_manager(self) -> FunctionToolManager:
|
||||
"""获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools"""
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
@@ -84,40 +88,14 @@ class Context:
|
||||
Returns:
|
||||
如果没找到,会返回 False
|
||||
"""
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(
|
||||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
||||
)
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""停用一个已经注册的函数调用工具。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False"""
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""
|
||||
@@ -125,7 +103,7 @@ class Context:
|
||||
"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
|
||||
@@ -141,51 +119,53 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str | None = None) -> Provider | None:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
||||
"""获取 AstrBot 的配置。"""
|
||||
return self._config
|
||||
if not umo:
|
||||
# using default config
|
||||
return self._config
|
||||
else:
|
||||
return self.astrbot_config_mgr.get_conf(umo)
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
@@ -197,9 +177,14 @@ class Context:
|
||||
"""
|
||||
return self._event_queue
|
||||
|
||||
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform:
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(
|
||||
self, platform_type: Union[PlatformAdapterType, str]
|
||||
) -> Platform | None:
|
||||
"""
|
||||
获取指定类型的平台适配器。
|
||||
|
||||
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
name = platform.meta().name
|
||||
@@ -213,6 +198,20 @@ class Context:
|
||||
):
|
||||
return platform
|
||||
|
||||
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
||||
"""
|
||||
获取指定 ID 的平台适配器实例。
|
||||
|
||||
Args:
|
||||
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
||||
|
||||
Returns:
|
||||
Platform: 平台适配器实例,如果未找到则返回 None。
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().id == platform_id:
|
||||
return platform
|
||||
|
||||
async def send_message(
|
||||
self, session: Union[str, MessageSesion], message_chain: MessageChain
|
||||
) -> bool:
|
||||
@@ -236,7 +235,7 @@ class Context:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
if platform.meta().id == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -8,22 +8,45 @@ from typing import Union
|
||||
class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
VCHAT = enum.auto()
|
||||
GEWECHAT = enum.auto()
|
||||
TELEGRAM = enum.auto()
|
||||
WECOM = enum.auto()
|
||||
LARK = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT | TELEGRAM | WECOM | LARK
|
||||
WECHATPADPRO = enum.auto()
|
||||
DINGTALK = enum.auto()
|
||||
DISCORD = enum.auto()
|
||||
SLACK = enum.auto()
|
||||
KOOK = enum.auto()
|
||||
VOCECHAT = enum.auto()
|
||||
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||
ALL = (
|
||||
AIOCQHTTP
|
||||
| QQOFFICIAL
|
||||
| TELEGRAM
|
||||
| WECOM
|
||||
| LARK
|
||||
| WECHATPADPRO
|
||||
| DINGTALK
|
||||
| DISCORD
|
||||
| SLACK
|
||||
| KOOK
|
||||
| VOCECHAT
|
||||
| WEIXIN_OFFICIAL_ACCOUNT
|
||||
)
|
||||
|
||||
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"vchat": PlatformAdapterType.VCHAT,
|
||||
"gewechat": PlatformAdapterType.GEWECHAT,
|
||||
"telegram": PlatformAdapterType.TELEGRAM,
|
||||
"wecom": PlatformAdapterType.WECOM,
|
||||
"lark": PlatformAdapterType.LARK,
|
||||
"dingtalk": PlatformAdapterType.DINGTALK,
|
||||
"discord": PlatformAdapterType.DISCORD,
|
||||
"slack": PlatformAdapterType.SLACK,
|
||||
"kook": PlatformAdapterType.KOOK,
|
||||
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
||||
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from .star_handler import (
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
register_agent,
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent,
|
||||
)
|
||||
@@ -28,6 +29,7 @@ __all__ = [
|
||||
"register_on_llm_request",
|
||||
"register_on_llm_response",
|
||||
"register_llm_tool",
|
||||
"register_agent",
|
||||
"register_on_decorating_result",
|
||||
"register_after_message_sent",
|
||||
]
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from ..star import star_registry, StarMetadata, star_map
|
||||
import warnings
|
||||
|
||||
from astrbot.core.star import StarMetadata, star_map
|
||||
|
||||
_warned_register_star = False
|
||||
|
||||
|
||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
||||
"""注册一个插件(Star)。
|
||||
|
||||
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||
在 v3.5.19 版本之后(不含),您不需要使用该装饰器来装饰插件类,
|
||||
AstrBot 会自动识别继承自 Star 的类并将其作为插件类加载。
|
||||
|
||||
Args:
|
||||
name: 插件名称。
|
||||
author: 作者。
|
||||
@@ -21,18 +29,32 @@ def register_star(name: str, author: str, desc: str, version: str, repo: str = N
|
||||
帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。`
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
star_metadata = StarMetadata(
|
||||
name=name,
|
||||
author=author,
|
||||
desc=desc,
|
||||
version=version,
|
||||
repo=repo,
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__,
|
||||
global _warned_register_star
|
||||
if not _warned_register_star:
|
||||
_warned_register_star = True
|
||||
warnings.warn(
|
||||
"The 'register_star' decorator is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
star_registry.append(star_metadata)
|
||||
star_map[cls.__module__] = star_metadata
|
||||
|
||||
def decorator(cls):
|
||||
if not star_map.get(cls.__module__):
|
||||
metadata = StarMetadata(
|
||||
name=name,
|
||||
author=author,
|
||||
desc=desc,
|
||||
version=version,
|
||||
repo=repo,
|
||||
)
|
||||
star_map[cls.__module__] = metadata
|
||||
else:
|
||||
star_map[cls.__module__].name = name
|
||||
star_map[cls.__module__].author = author
|
||||
star_map[cls.__module__].desc = desc
|
||||
star_map[cls.__module__].version = version
|
||||
star_map[cls.__module__].repo = repo
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -15,6 +15,11 @@ from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.agent.agent import Agent
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
|
||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||
@@ -306,7 +311,7 @@ def register_on_llm_response(**kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
def register_llm_tool(name: str = None, **kwargs):
|
||||
"""为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
@@ -340,6 +345,9 @@ def register_llm_tool(name: str = None):
|
||||
"""
|
||||
|
||||
name_ = name
|
||||
registering_agent = None
|
||||
if kwargs.get("registering_agent"):
|
||||
registering_agent = kwargs["registering_agent"]
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
@@ -357,15 +365,69 @@ def register_llm_tool(name: str = None):
|
||||
"description": arg.description,
|
||||
}
|
||||
)
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(
|
||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||
)
|
||||
# print(llm_tool_name, registering_agent)
|
||||
if not registering_agent:
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(
|
||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||
)
|
||||
else:
|
||||
assert isinstance(registering_agent, RegisteringAgent)
|
||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||
if registering_agent._agent.tools is None:
|
||||
registering_agent._agent.tools = []
|
||||
registering_agent._agent.tools.append(llm_tools.spec_to_func(
|
||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
||||
))
|
||||
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RegisteringAgent:
|
||||
"""用于 Agent 注册"""
|
||||
|
||||
def llm_tool(self, *args, **kwargs):
|
||||
kwargs["registering_agent"] = self
|
||||
return register_llm_tool(*args, **kwargs)
|
||||
|
||||
def __init__(self, agent: Agent[AstrAgentContext]):
|
||||
self._agent = agent
|
||||
|
||||
|
||||
def register_agent(
|
||||
name: str,
|
||||
instruction: str,
|
||||
tools: list[str | FunctionTool] = None,
|
||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
|
||||
):
|
||||
"""注册一个 Agent
|
||||
|
||||
Args:
|
||||
name: Agent 的名称
|
||||
instruction: Agent 的指令
|
||||
tools: Agent 使用的工具列表
|
||||
run_hooks: Agent 运行时的钩子函数
|
||||
"""
|
||||
tools_ = tools or []
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
AstrAgent = Agent[AstrAgentContext]
|
||||
agent = AstrAgent(
|
||||
name=name,
|
||||
instructions=instruction,
|
||||
tools=tools_,
|
||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
)
|
||||
handoff_tool = HandoffTool(agent=agent)
|
||||
handoff_tool.handler=awaitable
|
||||
llm_tools.func_list.append(handoff_tool)
|
||||
return RegisteringAgent(agent)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_decorating_result(**kwargs):
|
||||
"""在发送消息前的事件"""
|
||||
|
||||
|
||||
250
astrbot/core/star/session_llm_manager.py
Normal file
250
astrbot/core/star/session_llm_manager.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||
"""
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionServiceManager:
|
||||
"""管理会话级别的服务启停状态,包括LLM和TTS"""
|
||||
|
||||
# =============================================================================
|
||||
# LLM 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_services.get("llm_enabled")
|
||||
if llm_enabled is not None:
|
||||
return llm_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["llm_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# TTS 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查TTS是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
tts_enabled = session_services.get("tts_enabled")
|
||||
if tts_enabled is not None:
|
||||
return tts_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置TTS在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["tts_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理TTS请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话整体启停相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_session_enabled(session_id: str) -> bool:
|
||||
"""检查会话是否整体启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
session_enabled = session_services.get("session_enabled")
|
||||
if session_enabled is not None:
|
||||
return session_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_session_status(session_id: str, enabled: bool) -> None:
|
||||
"""设置会话的整体启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["session_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理会话请求(会话整体启停检查)
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_session_enabled(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话命名相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str | None:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
"""
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
||||
"""设置会话的自定义名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
if custom_name and custom_name.strip():
|
||||
session_config["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config.pop("custom_name", None)
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_display_name(session_id: str) -> str:
|
||||
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
str: 显示名称
|
||||
"""
|
||||
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||
if custom_name:
|
||||
return custom_name
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
150
astrbot/core/star/session_plugin_manager.py
Normal file
150
astrbot/core/star/session_plugin_manager.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
会话插件管理器 - 负责管理每个会话的插件启停状态
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict, List
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionPluginManager:
|
||||
"""管理会话级别的插件启停状态"""
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||
"""检查插件是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
# 如果插件在禁用列表中,返回False
|
||||
if plugin_name in disabled_plugins:
|
||||
return False
|
||||
|
||||
# 如果插件在启用列表中,返回True
|
||||
if plugin_name in enabled_plugins:
|
||||
return True
|
||||
|
||||
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_plugin_status_for_session(
|
||||
session_id: str, plugin_name: str, enabled: bool
|
||||
) -> None:
|
||||
"""设置插件在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": [],
|
||||
}
|
||||
|
||||
session_config = session_plugin_config[session_id]
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
if enabled:
|
||||
# 启用插件
|
||||
if plugin_name in disabled_plugins:
|
||||
disabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in enabled_plugins:
|
||||
enabled_plugins.append(plugin_name)
|
||||
else:
|
||||
# 禁用插件
|
||||
if plugin_name in enabled_plugins:
|
||||
enabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in disabled_plugins:
|
||||
disabled_plugins.append(plugin_name)
|
||||
|
||||
# 保存配置
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put(
|
||||
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]:
|
||||
"""获取指定会话的插件配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
"""
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
return session_plugin_config.get(
|
||||
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
handlers: 原始处理器列表
|
||||
|
||||
Returns:
|
||||
List: 过滤后的处理器列表
|
||||
"""
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
session_id = event.unified_msg_origin
|
||||
filtered_handlers = []
|
||||
|
||||
for handler in handlers:
|
||||
# 获取处理器对应的插件
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not plugin:
|
||||
# 如果找不到插件元数据,允许执行(可能是系统插件)
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
# 跳过保留插件(系统插件)
|
||||
if plugin.reserved:
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin.name
|
||||
):
|
||||
filtered_handlers.append(handler)
|
||||
else:
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}"
|
||||
)
|
||||
|
||||
return filtered_handlers
|
||||
@@ -1,14 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
star_map: Dict[str, StarMetadata] = {}
|
||||
star_registry: list[StarMetadata] = []
|
||||
star_map: dict[str, StarMetadata] = {}
|
||||
"""key 是模块路径,__module__"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Star
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
@@ -18,22 +22,27 @@ class StarMetadata:
|
||||
当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。
|
||||
"""
|
||||
|
||||
name: str
|
||||
author: str # 插件作者
|
||||
desc: str # 插件简介
|
||||
version: str # 插件版本
|
||||
repo: str = None # 插件仓库地址
|
||||
name: str | None = None
|
||||
"""插件名"""
|
||||
author: str | None = None
|
||||
"""插件作者"""
|
||||
desc: str | None = None
|
||||
"""插件简介"""
|
||||
version: str | None = None
|
||||
"""插件版本"""
|
||||
repo: str | None = None
|
||||
"""插件仓库地址"""
|
||||
|
||||
star_cls_type: type = None
|
||||
star_cls_type: type[Star] | None = None
|
||||
"""插件的类对象的类型"""
|
||||
module_path: str = None
|
||||
module_path: str | None = None
|
||||
"""插件的模块路径"""
|
||||
|
||||
star_cls: object = None
|
||||
star_cls: Star | None = None
|
||||
"""插件的类对象"""
|
||||
module: ModuleType = None
|
||||
module: ModuleType | None = None
|
||||
"""插件的模块对象"""
|
||||
root_dir_name: str = None
|
||||
root_dir_name: str | None = None
|
||||
"""插件的目录名称"""
|
||||
reserved: bool = False
|
||||
"""是否是 AstrBot 的保留插件"""
|
||||
@@ -41,35 +50,14 @@ class StarMetadata:
|
||||
activated: bool = True
|
||||
"""是否被激活"""
|
||||
|
||||
config: AstrBotConfig = None
|
||||
config: AstrBotConfig | None = None
|
||||
"""插件配置"""
|
||||
|
||||
star_handler_full_names: List[str] = field(default_factory=list)
|
||||
star_handler_full_names: list[str] = field(default_factory=list)
|
||||
"""注册的 Handler 的全名列表"""
|
||||
|
||||
supported_platforms: Dict[str, bool] = field(default_factory=dict)
|
||||
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||
|
||||
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||
"""更新插件支持的平台列表
|
||||
|
||||
Args:
|
||||
plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项
|
||||
"""
|
||||
if not plugin_enable_config:
|
||||
return
|
||||
|
||||
# 清空之前的配置
|
||||
self.supported_platforms.clear()
|
||||
|
||||
# 遍历所有平台配置
|
||||
for platform_id, plugins in plugin_enable_config.items():
|
||||
# 检查该插件在当前平台的配置
|
||||
if self.name in plugins:
|
||||
self.supported_platforms[platform_id] = plugins[self.name]
|
||||
else:
|
||||
# 如果没有明确配置,默认为启用
|
||||
self.supported_platforms[platform_id] = True
|
||||
def __repr__(self) -> str:
|
||||
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||
|
||||
@@ -7,6 +7,7 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -26,7 +27,10 @@ class StarHandlerRegistry(Generic[T]):
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
self,
|
||||
event_type: EventType,
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> List[StarHandlerMetadata]:
|
||||
handlers = []
|
||||
for handler in self._handlers:
|
||||
@@ -36,8 +40,15 @@ class StarHandlerRegistry(Generic[T]):
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
if plugins_name is not None and plugins_name != ["*"]:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not plugin:
|
||||
continue
|
||||
if (
|
||||
plugin.name not in plugins_name
|
||||
and event_type != EventType.OnAstrBotLoadedEvent
|
||||
and not plugin.reserved
|
||||
):
|
||||
continue
|
||||
handlers.append(handler)
|
||||
return handlers
|
||||
@@ -49,7 +60,8 @@ class StarHandlerRegistry(Generic[T]):
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
return [
|
||||
handler for handler in self._handlers
|
||||
handler
|
||||
for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
@@ -67,6 +79,7 @@ class StarHandlerRegistry(Generic[T]):
|
||||
def __len__(self):
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
@@ -119,32 +132,3 @@ class StarHandlerMetadata:
|
||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||
"priority", 0
|
||||
)
|
||||
|
||||
def is_enabled_for_platform(self, platform_id: str) -> bool:
|
||||
"""检查插件是否在指定平台启用
|
||||
|
||||
Args:
|
||||
platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例
|
||||
|
||||
Returns:
|
||||
bool: 是否启用,True表示启用,False表示禁用
|
||||
"""
|
||||
plugin = star_map.get(self.handler_module_path)
|
||||
|
||||
# 如果插件元数据不存在,默认允许执行
|
||||
if not plugin or not plugin.name:
|
||||
return True
|
||||
|
||||
# 先检查插件是否被激活
|
||||
if not plugin.activated:
|
||||
return False
|
||||
|
||||
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
|
||||
if (
|
||||
hasattr(plugin, "supported_platforms")
|
||||
and platform_id in plugin.supported_platforms
|
||||
):
|
||||
return plugin.supported_platforms[platform_id]
|
||||
|
||||
# 如果没有缓存数据,默认允许执行
|
||||
return True
|
||||
|
||||
@@ -11,7 +11,6 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -23,6 +22,7 @@ from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from astrbot.core.agent.handoff import HandoffTool, FunctionTool
|
||||
|
||||
from . import StarMetadata
|
||||
from .context import Context
|
||||
@@ -37,12 +37,6 @@ except ImportError:
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig):
|
||||
@@ -64,6 +58,8 @@ class PluginManager:
|
||||
"""保留插件的路径。在 packages 目录下"""
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
"""插件配置 Schema 文件名"""
|
||||
self._pm_lock = asyncio.Lock()
|
||||
"""StarManager操作互斥锁"""
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
@@ -119,7 +115,8 @@ class PluginManager:
|
||||
reloaded_plugins.add(plugin_name)
|
||||
break
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
@staticmethod
|
||||
def _get_classes(arg: ModuleType):
|
||||
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
||||
classes = []
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
@@ -129,7 +126,8 @@ class PluginManager:
|
||||
break
|
||||
return classes
|
||||
|
||||
def _get_modules(self, path):
|
||||
@staticmethod
|
||||
def _get_modules(path):
|
||||
modules = []
|
||||
|
||||
dirs = os.listdir(path)
|
||||
@@ -155,7 +153,7 @@ class PluginManager:
|
||||
)
|
||||
return modules
|
||||
|
||||
def _get_plugin_modules(self) -> List[dict]:
|
||||
def _get_plugin_modules(self) -> list[dict]:
|
||||
plugins = []
|
||||
if os.path.exists(self.plugin_store_path):
|
||||
plugins.extend(self._get_modules(self.plugin_store_path))
|
||||
@@ -166,7 +164,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str | None = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -189,10 +187,11 @@ class PluginManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj=None) -> StarMetadata:
|
||||
"""v3.4.0 以前的方式载入插件元数据
|
||||
@staticmethod
|
||||
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
|
||||
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
|
||||
先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
|
||||
"""
|
||||
metadata = None
|
||||
|
||||
@@ -204,11 +203,14 @@ class PluginManager:
|
||||
os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8"
|
||||
) as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
elif plugin_obj:
|
||||
elif plugin_obj and hasattr(plugin_obj, "info"):
|
||||
# 使用 info() 函数
|
||||
metadata = plugin_obj.info()
|
||||
|
||||
if isinstance(metadata, dict):
|
||||
if "desc" not in metadata and "description" in metadata:
|
||||
metadata["desc"] = metadata["description"]
|
||||
|
||||
if (
|
||||
"name" not in metadata
|
||||
or "desc" not in metadata
|
||||
@@ -228,8 +230,9 @@ class PluginManager:
|
||||
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_related_modules(
|
||||
self, plugin_root_dir: str, is_reserved: bool = False
|
||||
plugin_root_dir: str, is_reserved: bool = False
|
||||
) -> list[str]:
|
||||
"""获取与指定插件相关的所有已加载模块名
|
||||
|
||||
@@ -251,8 +254,8 @@ class PluginManager:
|
||||
|
||||
def _purge_modules(
|
||||
self,
|
||||
module_patterns: list[str] = None,
|
||||
root_dir_name: str = None,
|
||||
module_patterns: list[str] | None = None,
|
||||
root_dir_name: str | None = None,
|
||||
is_reserved: bool = False,
|
||||
):
|
||||
"""从 sys.modules 中移除指定的模块
|
||||
@@ -293,69 +296,48 @@ class PluginManager:
|
||||
- success (bool): 重载是否成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
if smd.name == specified_plugin_name:
|
||||
specified_module_path = smd.module_path
|
||||
break
|
||||
async with self._pm_lock:
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
if smd.name == specified_plugin_name:
|
||||
specified_module_path = smd.module_path
|
||||
break
|
||||
|
||||
# 终止插件
|
||||
if not specified_module_path:
|
||||
# 重载所有插件
|
||||
for smd in star_registry:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
# 终止插件
|
||||
if not specified_module_path:
|
||||
# 重载所有插件
|
||||
for smd in star_registry:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
if smd.name and smd.module_path:
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
if smd:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
if smd.name:
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
if smd:
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||
)
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
# 更新所有插件的平台兼容性
|
||||
await self.update_all_platform_compatibility()
|
||||
|
||||
return result
|
||||
|
||||
async def update_all_platform_compatibility(self):
|
||||
"""更新所有插件的平台兼容性设置"""
|
||||
# 获取最新的平台插件启用配置
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
logger.debug(
|
||||
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
|
||||
)
|
||||
|
||||
# 遍历所有插件,更新平台兼容性
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin.update_platform_compatibility(plugin_enable_config)
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
|
||||
)
|
||||
|
||||
return True
|
||||
return result
|
||||
|
||||
async def load(self, specified_module_path=None, specified_dir_name=None):
|
||||
"""载入插件。
|
||||
@@ -370,10 +352,9 @@ class PluginManager:
|
||||
- success (bool): 是否全部加载成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
alter_cmd = sp.get("alter_cmd", {})
|
||||
inactivated_plugins = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
|
||||
alter_cmd = await sp.global_get("alter_cmd", {})
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
@@ -435,7 +416,7 @@ class PluginManager:
|
||||
)
|
||||
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
# 通过 __init__subclass__ 注册插件
|
||||
metadata = star_map[path]
|
||||
|
||||
try:
|
||||
@@ -449,13 +430,15 @@ class PluginManager:
|
||||
metadata.desc = metadata_yaml.desc
|
||||
metadata.version = metadata_yaml.version
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"插件 {root_dir_name} 元数据载入失败: {str(e)}。使用默认元数据。"
|
||||
)
|
||||
logger.info(metadata)
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
# metadata.config = plugin_config
|
||||
if plugin_config and metadata.star_cls_type:
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -464,7 +447,7 @@ class PluginManager:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
else:
|
||||
elif metadata.star_cls_type:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context
|
||||
)
|
||||
@@ -475,11 +458,9 @@ class PluginManager:
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
# 更新插件的平台兼容性
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = (
|
||||
@@ -489,20 +470,32 @@ class PluginManager:
|
||||
)
|
||||
for handler in related_handlers:
|
||||
handler.handler = functools.partial(
|
||||
handler.handler, metadata.star_cls
|
||||
handler.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
if (
|
||||
func_tool.handler
|
||||
and func_tool.handler.__module__ == metadata.module_path
|
||||
):
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(
|
||||
func_tool.handler, metadata.star_cls
|
||||
)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
if isinstance(func_tool, HandoffTool):
|
||||
need_apply = []
|
||||
sub_tools = func_tool.agent.tools
|
||||
for sub_tool in sub_tools:
|
||||
if isinstance(sub_tool, FunctionTool):
|
||||
need_apply.append(sub_tool)
|
||||
else:
|
||||
need_apply = [func_tool]
|
||||
|
||||
for ft in need_apply:
|
||||
if (
|
||||
ft.handler
|
||||
and ft.handler.__module__ == metadata.module_path
|
||||
):
|
||||
ft.handler_module_path = metadata.module_path
|
||||
ft.handler = functools.partial(
|
||||
ft.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
if ft.name in inactivated_llm_tools:
|
||||
ft.active = False
|
||||
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
@@ -526,13 +519,12 @@ class PluginManager:
|
||||
obj = getattr(module, classes[0])(
|
||||
context=self.context
|
||||
) # 实例化插件类
|
||||
else:
|
||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||
|
||||
metadata = None
|
||||
metadata = self._load_plugin_metadata(
|
||||
plugin_path=plugin_dir_path, plugin_obj=obj
|
||||
)
|
||||
if not metadata:
|
||||
raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
@@ -547,6 +539,10 @@ class PluginManager:
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
|
||||
full_names = []
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||
metadata.module_path
|
||||
@@ -586,7 +582,7 @@ class PluginManager:
|
||||
metadata.star_handler_full_names = full_names
|
||||
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
@@ -622,43 +618,45 @@ class PluginManager:
|
||||
- readme: README.md 文件的内容(如果存在)
|
||||
如果找不到插件元数据则返回 None。
|
||||
"""
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
await self.load(specified_dir_name=dir_name)
|
||||
async with self._pm_lock:
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
await self.load(specified_dir_name=dir_name)
|
||||
|
||||
# Get the plugin metadata to return repo info
|
||||
plugin = self.context.get_registered_star(dir_name)
|
||||
if not plugin:
|
||||
# Try to find by other name if directory name doesn't match plugin name
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
plugin = star
|
||||
break
|
||||
# Get the plugin metadata to return repo info
|
||||
plugin = self.context.get_registered_star(dir_name)
|
||||
if not plugin:
|
||||
# Try to find by other name if directory name doesn't match plugin name
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
plugin = star
|
||||
break
|
||||
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(plugin_path, "README.md")
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(plugin_path, "README.md")
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path) and nh3:
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
if os.path.exists(readme_path):
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}"
|
||||
)
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {
|
||||
"repo": plugin.repo,
|
||||
"readme": cleaned_content,
|
||||
"name": plugin.name,
|
||||
}
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {
|
||||
"repo": plugin.repo,
|
||||
"readme": readme_content,
|
||||
"name": plugin.name,
|
||||
}
|
||||
|
||||
return plugin_info
|
||||
return plugin_info
|
||||
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
"""卸载指定的插件。
|
||||
@@ -669,32 +667,33 @@ class PluginManager:
|
||||
Raises:
|
||||
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||
"""
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
async with self._pm_lock:
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# 终止插件
|
||||
try:
|
||||
await self._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
|
||||
)
|
||||
# 终止插件
|
||||
try:
|
||||
await self._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
|
||||
)
|
||||
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
|
||||
try:
|
||||
remove_dir(os.path.join(ppath, root_dir_name))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
||||
)
|
||||
try:
|
||||
remove_dir(os.path.join(ppath, root_dir_name))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
||||
)
|
||||
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
"""解绑并移除一个插件。
|
||||
@@ -725,6 +724,9 @@ class PluginManager:
|
||||
]:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
if plugin is None:
|
||||
return
|
||||
|
||||
self._purge_modules(
|
||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||
)
|
||||
@@ -747,35 +749,37 @@ class PluginManager:
|
||||
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
|
||||
并且同时将插件启用的 llm_tool 禁用。
|
||||
"""
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
async with self._pm_lock:
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
|
||||
# 调用插件的终止方法
|
||||
await self._terminate_plugin(plugin)
|
||||
# 调用插件的终止方法
|
||||
await self._terminate_plugin(plugin)
|
||||
|
||||
# 加入到 shared_preferences 中
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
# 加入到 shared_preferences 中
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = list(
|
||||
set(sp.get("inactivated_llm_tools", []))
|
||||
) # 后向兼容
|
||||
inactivated_llm_tools: list = list(
|
||||
set(await sp.global_get("inactivated_llm_tools", []))
|
||||
) # 后向兼容
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
await sp.global_put("inactivated_plugins", inactivated_plugins)
|
||||
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = False
|
||||
plugin.activated = False
|
||||
|
||||
async def _terminate_plugin(self, star_metadata: StarMetadata):
|
||||
@staticmethod
|
||||
async def _terminate_plugin(star_metadata: StarMetadata):
|
||||
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
||||
logger.info(f"正在终止插件 {star_metadata.name} ...")
|
||||
|
||||
@@ -784,20 +788,23 @@ class PluginManager:
|
||||
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
||||
return
|
||||
|
||||
if hasattr(star_metadata.star_cls, "__del__"):
|
||||
if star_metadata.star_cls is None:
|
||||
return
|
||||
|
||||
if '__del__' in star_metadata.star_cls_type.__dict__:
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
)
|
||||
elif hasattr(star_metadata.star_cls, "terminate"):
|
||||
elif 'terminate' in star_metadata.star_cls_type.__dict__:
|
||||
await star_metadata.star_cls.terminate()
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||
if plugin.module_path in inactivated_plugins:
|
||||
inactivated_plugins.remove(plugin.module_path)
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
await sp.global_put("inactivated_plugins", inactivated_plugins)
|
||||
|
||||
# 启用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
@@ -807,7 +814,7 @@ class PluginManager:
|
||||
):
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
func_tool.active = True
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
await self.reload(plugin_name)
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
file_url = update_data[0]["zipball_url"]
|
||||
elif str(version).startswith("v"):
|
||||
# 更新到指定版本
|
||||
logger.info(f"正在更新到指定版本: {version}")
|
||||
for data in update_data:
|
||||
if data["tag_name"] == version:
|
||||
file_url = data["zipball_url"]
|
||||
@@ -98,8 +97,8 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
else:
|
||||
if len(str(version)) != 40:
|
||||
raise Exception("commit hash 长度不正确,应为 40")
|
||||
logger.info(f"正在尝试更新到指定 commit: {version}")
|
||||
file_url = "https://github.com/Soulter/AstrBot/archive/" + version + ".zip"
|
||||
file_url = f"https://github.com/Soulter/AstrBot/archive/{version}.zip"
|
||||
logger.info(f"准备更新至指定版本的 AstrBot Core: {version}")
|
||||
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
@@ -107,6 +106,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
try:
|
||||
await download_file(file_url, "temp.zip")
|
||||
logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...")
|
||||
self.unzip_file("temp.zip", self.MAIN_PATH)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
@@ -8,6 +8,7 @@ import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
import psutil
|
||||
import logging
|
||||
|
||||
import certifi
|
||||
|
||||
@@ -16,6 +17,8 @@ from typing import Union
|
||||
from PIL import Image
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
def on_error(func, path, exc_info):
|
||||
"""
|
||||
@@ -30,7 +33,7 @@ def on_error(func, path, exc_info):
|
||||
raise exc_info[1]
|
||||
|
||||
|
||||
def remove_dir(file_path) -> bool:
|
||||
def remove_dir(file_path: str) -> bool:
|
||||
if not os.path.exists(file_path):
|
||||
return True
|
||||
shutil.rmtree(file_path, onerror=on_error)
|
||||
@@ -212,19 +215,50 @@ async def get_dashboard_version():
|
||||
return None
|
||||
|
||||
|
||||
async def download_dashboard(path: str = None, extract_path: str = "data"):
|
||||
async def download_dashboard(
|
||||
path: str | None = None,
|
||||
extract_path: str = "data",
|
||||
latest: bool = True,
|
||||
version: str | None = None,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
"""下载管理面板文件"""
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = (
|
||||
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
if latest or len(str(version)) != 40:
|
||||
logger.info("准备下载最新发行版本的 AstrBot WebUI")
|
||||
ver_name = "latest" if latest else version
|
||||
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
||||
try:
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
except BaseException as _:
|
||||
if latest:
|
||||
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
else:
|
||||
dashboard_release_url = f"https://github.com/Soulter/AstrBot/releases/download/{version}/dist.zip"
|
||||
if proxy:
|
||||
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
else:
|
||||
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
|
||||
|
||||
url = (
|
||||
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
if proxy:
|
||||
url = f"{proxy}/{url}"
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
releases = await resp.json()
|
||||
for release in releases:
|
||||
if version in release["tag_name"]:
|
||||
download_url = release["assets"][0]["browser_download_url"]
|
||||
await download_file(download_url, path, show_progress=True)
|
||||
else:
|
||||
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
|
||||
return
|
||||
|
||||
with zipfile.ZipFile(path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
|
||||
@@ -58,9 +58,10 @@ class Metric:
|
||||
pass
|
||||
try:
|
||||
if "adapter_name" in kwargs:
|
||||
db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1})
|
||||
if "llm_name" in kwargs:
|
||||
db_helper.insert_llm_metrics({kwargs["llm_name"]: 1})
|
||||
await db_helper.insert_platform_stats(
|
||||
platform_id=kwargs["adapter_name"],
|
||||
platform_type=kwargs.get("adapter_type", "unknown"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"保存指标到数据库失败: {e}")
|
||||
pass
|
||||
|
||||
29
astrbot/core/utils/session_lock.py
Normal file
29
astrbot/core/utils/session_lock.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
class SessionLockManager:
|
||||
def __init__(self):
|
||||
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._lock_count: dict[str, int] = defaultdict(int)
|
||||
self._access_lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire_lock(self, session_id: str):
|
||||
async with self._access_lock:
|
||||
lock = self._locks[session_id]
|
||||
self._lock_count[session_id] += 1
|
||||
|
||||
try:
|
||||
async with lock:
|
||||
yield
|
||||
finally:
|
||||
async with self._access_lock:
|
||||
self._lock_count[session_id] -= 1
|
||||
if self._lock_count[session_id] == 0:
|
||||
self._locks.pop(session_id, None)
|
||||
self._lock_count.pop(session_id, None)
|
||||
|
||||
|
||||
session_lock_manager = SessionLockManager()
|
||||
@@ -1,41 +1,180 @@
|
||||
import json
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Preference
|
||||
import threading
|
||||
import asyncio
|
||||
import os
|
||||
from typing import TypeVar, Any, overload
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
def __init__(self, db_helper: BaseDatabase, json_storage_path=None):
|
||||
if json_storage_path is None:
|
||||
json_storage_path = os.path.join(
|
||||
get_astrbot_data_path(), "shared_preferences.json"
|
||||
)
|
||||
self.path = json_storage_path
|
||||
self.db_helper = db_helper
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
self._sync_loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=self._sync_loop.run_forever, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
async def get_async(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
) -> _VT:
|
||||
"""获取指定范围和键的偏好设置"""
|
||||
if scope_id is not None and key is not None:
|
||||
result = await self.db_helper.get_preference(scope, scope_id, key)
|
||||
if result:
|
||||
ret = result.value["val"]
|
||||
else:
|
||||
ret = default
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference."
|
||||
)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._data.get(key, default)
|
||||
async def range_get_async(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""获取指定范围的偏好设置
|
||||
Note: 返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。scope_id 和 key 可以为 None,这时返回该范围下所有的偏好设置。
|
||||
"""
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: None, key: str, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: str, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: None, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
async def session_get(
|
||||
self, umo: str | None, key: str | None = None, default: _VT = None
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取会话范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if umo is None or key is None:
|
||||
return await self.range_get_async("umo", umo, key)
|
||||
return await self.get_async("umo", umo, key, default)
|
||||
|
||||
@overload
|
||||
async def global_get(self, key: None, default: Any = None) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def global_get(self, key: str, default: _VT = None) -> _VT: ...
|
||||
|
||||
async def global_get(
|
||||
self, key: str | None, default: _VT = None
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取全局范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if key is None:
|
||||
return await self.range_get_async("global", "global", key)
|
||||
return await self.get_async("global", "global", key, default)
|
||||
|
||||
async def put_async(self, scope: str, scope_id: str, key: str, value: Any):
|
||||
"""设置指定范围和键的偏好设置"""
|
||||
await self.db_helper.insert_preference_or_update(
|
||||
scope, scope_id, key, {"val": value}
|
||||
)
|
||||
|
||||
async def session_put(self, umo: str, key: str, value: Any):
|
||||
await self.put_async("umo", umo, key, value)
|
||||
|
||||
async def global_put(self, key: str, value: Any):
|
||||
await self.put_async("global", "global", key, value)
|
||||
|
||||
async def remove_async(self, scope: str, scope_id: str, key: str):
|
||||
"""删除指定范围和键的偏好设置"""
|
||||
await self.db_helper.remove_preference(scope, scope_id, key)
|
||||
|
||||
async def session_remove(self, umo: str, key: str):
|
||||
await self.remove_async("umo", umo, key)
|
||||
|
||||
async def global_remove(self, key: str):
|
||||
"""删除全局偏好设置"""
|
||||
await self.remove_async("global", "global", key)
|
||||
|
||||
async def clear_async(self, scope: str, scope_id: str):
|
||||
"""清空指定范围的所有偏好设置"""
|
||||
await self.db_helper.clear_preferences(scope, scope_id)
|
||||
|
||||
# ====
|
||||
# DEPRECATED METHODS
|
||||
# ====
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
scope: str | None = None,
|
||||
scope_id: str | None = "",
|
||||
) -> _VT:
|
||||
"""获取偏好设置(已弃用)"""
|
||||
if scope_id == "":
|
||||
scope_id = "unknown"
|
||||
if scope_id is None or key is None:
|
||||
# result = asyncio.run(self.range_get_async(scope, scope_id, key))
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference."
|
||||
)
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
self.get_async(scope or "unknown", scope_id or "unknown", key, default),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
return result if result is not None else default
|
||||
|
||||
def range_get(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""获取指定范围的偏好设置(已弃用)"""
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
self.range_get_async(scope, scope_id, key), self._sync_loop
|
||||
).result()
|
||||
|
||||
return result
|
||||
|
||||
def put(self, key, value, scope: str | None = None, scope_id: str | None = None):
|
||||
"""设置偏好设置(已弃用)"""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.put_async(scope or "unknown", scope_id or "unknown", key, value),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
def remove(self, key, scope: str | None = None, scope_id: str | None = None):
|
||||
"""删除偏好设置(已弃用)"""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_async(scope or "unknown", scope_id or "unknown", key),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
def clear(self, scope: str | None = None, scope_id: str | None = None):
|
||||
"""清空偏好设置(已弃用)"""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.clear_async(scope or "unknown", scope_id or "unknown"),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
@@ -1,37 +1,76 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import certifi
|
||||
|
||||
import logging
|
||||
import random
|
||||
from . import RenderStrategy
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
CUSTOM_T2I_TEMPLATE_PATH = os.path.join(get_astrbot_data_path(), "t2i_template.html")
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
class NetworkRenderStrategy(RenderStrategy):
|
||||
def __init__(self, base_url: str | None = None) -> None:
|
||||
super().__init__()
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
|
||||
self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
else:
|
||||
self.BASE_RENDER_URL = self._clean_url(base_url)
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template", "base.html")
|
||||
with open(self.TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
||||
self.DEFAULT_TEMPLATE = f.read()
|
||||
|
||||
if self.BASE_RENDER_URL.endswith("/"):
|
||||
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
|
||||
if not self.BASE_RENDER_URL.endswith("text2img"):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
self.endpoints = [self.BASE_RENDER_URL]
|
||||
|
||||
def set_endpoint(self, base_url: str):
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
async def initialize(self):
|
||||
if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT:
|
||||
asyncio.create_task(self.get_official_endpoints())
|
||||
|
||||
if self.BASE_RENDER_URL.endswith("/"):
|
||||
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
|
||||
if not self.BASE_RENDER_URL.endswith("text2img"):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
async def get_template(self) -> str:
|
||||
"""获取文转图 HTML 模板
|
||||
|
||||
Returns:
|
||||
str: 文转图 HTML 模板字符串
|
||||
"""
|
||||
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
|
||||
with open(CUSTOM_T2I_TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
return self.DEFAULT_TEMPLATE
|
||||
|
||||
async def get_official_endpoints(self):
|
||||
"""获取官方的 t2i 端点列表。"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.soulter.top/astrbot/t2i-endpoints"
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
all_endpoints: list[dict] = data.get("data", [])
|
||||
self.endpoints = [
|
||||
ep.get("url")
|
||||
for ep in all_endpoints
|
||||
if ep.get("active") and ep.get("url")
|
||||
]
|
||||
logger.info(
|
||||
f"Successfully got {len(self.endpoints)} official T2I endpoints."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get official endpoints: {e}")
|
||||
|
||||
def _clean_url(self, url: str):
|
||||
if url.endswith("/"):
|
||||
url = url[:-1]
|
||||
if not url.endswith("text2img"):
|
||||
url += "/text2img"
|
||||
return url
|
||||
|
||||
async def render_custom_template(
|
||||
self,
|
||||
@@ -41,6 +80,7 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
"""使用自定义文转图模板"""
|
||||
|
||||
default_options = {"full_page": True, "type": "jpeg", "quality": 40}
|
||||
if options:
|
||||
default_options |= options
|
||||
@@ -51,30 +91,44 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
"tmpldata": tmpl_data,
|
||||
"options": default_options,
|
||||
}
|
||||
if return_url:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, connector=connector
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.BASE_RENDER_URL}/generate", json=post_data
|
||||
) as resp:
|
||||
ret = await resp.json()
|
||||
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
|
||||
return await download_image_by_url(
|
||||
f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data
|
||||
)
|
||||
|
||||
endpoints = self.endpoints.copy() if self.endpoints else [self.BASE_RENDER_URL]
|
||||
random.shuffle(endpoints)
|
||||
last_exception = None
|
||||
for endpoint in endpoints:
|
||||
try:
|
||||
if return_url:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, connector=connector
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{endpoint}/generate", json=post_data
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
ret = await resp.json()
|
||||
return f"{endpoint}/{ret['data']['id']}"
|
||||
else:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
else:
|
||||
# download_image_by_url 失败时抛异常
|
||||
return await download_image_by_url(
|
||||
f"{endpoint}/generate", post=True, post_data=post_data
|
||||
)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(f"Endpoint {endpoint} failed: {e}, trying next...")
|
||||
continue
|
||||
# 全部失败
|
||||
logger.error(f"All endpoints failed: {last_exception}")
|
||||
raise RuntimeError(f"All endpoints failed: {last_exception}")
|
||||
|
||||
async def render(self, text: str, return_url: bool = False) -> str:
|
||||
"""
|
||||
返回图像的文件路径
|
||||
"""
|
||||
with open(
|
||||
os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding="utf-8"
|
||||
) as f:
|
||||
tmpl_str = f.read()
|
||||
assert tmpl_str
|
||||
tmpl_str = await self.get_template()
|
||||
text = text.replace("`", "\\`")
|
||||
return await self.render_custom_template(
|
||||
tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user