Compare commits
207 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07ba9c772c | ||
|
|
0622d88b22 | ||
|
|
594f0fed55 | ||
|
|
04b0d9b88d | ||
|
|
1f2af8ef94 | ||
|
|
598ea2d857 | ||
|
|
6dd9bbb516 | ||
|
|
3cd0b47dc6 | ||
|
|
65c71b5f20 | ||
|
|
1152b11202 | ||
|
|
51246ea31b | ||
|
|
7e5592dd32 | ||
|
|
c6b28caebf | ||
|
|
ca002f6fff | ||
|
|
14ec392091 | ||
|
|
5e2eb91ac0 | ||
|
|
c1626613ce | ||
|
|
42042d9e73 | ||
|
|
22c3b53ab8 | ||
|
|
090c32c90e | ||
|
|
4f4a9b9e55 | ||
|
|
6c7d7c9015 | ||
|
|
562e62a8c0 | ||
|
|
0823f7aa48 | ||
|
|
eb201c0420 | ||
|
|
6cfed9a39d | ||
|
|
33618c4a6b | ||
|
|
ace0a7c219 | ||
|
|
f7d018cf94 | ||
|
|
8ae2a556e4 | ||
|
|
4188deb386 | ||
|
|
82cf4ed909 | ||
|
|
88fc437abc | ||
|
|
57f868cab1 | ||
|
|
6cb5527894 | ||
|
|
016783a1e5 | ||
|
|
594ccff9c8 | ||
|
|
30792f0584 | ||
|
|
8f021eb35a | ||
|
|
1969abc340 | ||
|
|
b1b53ab983 | ||
|
|
9b5af23982 | ||
|
|
4cedc6d3c8 | ||
|
|
4e9cce76da | ||
|
|
9b004f3d2f | ||
|
|
9430e3090d | ||
|
|
ba44f9117b | ||
|
|
eb56710a72 | ||
|
|
38e3f27899 | ||
|
|
3c58d96db5 | ||
|
|
a6be0cc135 | ||
|
|
a53510bc41 | ||
|
|
1fd482e899 | ||
|
|
2f130ba009 | ||
|
|
e6d9db9395 | ||
|
|
e0ac743cdb | ||
|
|
b0d3fc11f0 | ||
|
|
7e0a50fbf2 | ||
|
|
59df244173 | ||
|
|
deb31a02cf | ||
|
|
e3aa1315ae | ||
|
|
65bc5efa19 | ||
|
|
abc4bc24b4 | ||
|
|
5df3f06f83 | ||
|
|
0e1de82bd7 | ||
|
|
f31e41b3f1 | ||
|
|
fe8d2718c4 | ||
|
|
8afefada0a | ||
|
|
745e1c37c0 | ||
|
|
fdb5988cec | ||
|
|
36ffcf3cc3 | ||
|
|
a0f8f3ae32 | ||
|
|
130f52f315 | ||
|
|
a05868cc45 | ||
|
|
2fc77aed15 | ||
|
|
c56edb4da6 | ||
|
|
6672190760 | ||
|
|
f122b17097 | ||
|
|
2c5f68e696 | ||
|
|
e1ca645a32 | ||
|
|
333bf56ddc | ||
|
|
b240594859 | ||
|
|
beccae933f | ||
|
|
e6aa1d2c54 | ||
|
|
5e808bab65 | ||
|
|
361d78247b | ||
|
|
3550103e45 | ||
|
|
8b0d4d4de4 | ||
|
|
dc71c04b67 | ||
|
|
a0254ed817 | ||
|
|
2563ecf3c5 | ||
|
|
c04738d9fe | ||
|
|
1266b4d086 | ||
|
|
99cf0a1522 | ||
|
|
98a75e923d | ||
|
|
ad96d676e6 | ||
|
|
79333bbc35 | ||
|
|
5c5b0f4fde | ||
|
|
ed6cdfedbb | ||
|
|
23f13ef05f | ||
|
|
f9c59d9706 | ||
|
|
e1cec42227 | ||
|
|
8d79c50d53 | ||
|
|
d77830b97f | ||
|
|
394540f689 | ||
|
|
7d776e0ce2 | ||
|
|
17df1692b9 | ||
|
|
9ab652641d | ||
|
|
9119f7166f | ||
|
|
da7d9d8eb9 | ||
|
|
80fccc90b7 | ||
|
|
dcebc70f1a | ||
|
|
259e7bc322 | ||
|
|
37bdb6c6f6 | ||
|
|
dc71afdd3f | ||
|
|
44638108d0 | ||
|
|
93fcac498c | ||
|
|
79e2743aac | ||
|
|
5e9c7cdd91 | ||
|
|
6f73e5087d | ||
|
|
8c120b020e | ||
|
|
12fc6f9d38 | ||
|
|
a6e8483b4c | ||
|
|
7191d28ada | ||
|
|
e6b5e3d282 | ||
|
|
1413d6b5fe | ||
|
|
dcd8a1094c | ||
|
|
e64b31b9ba | ||
|
|
080f347511 | ||
|
|
eaaff4298d | ||
|
|
dd5a02e8ef | ||
|
|
3211ec57ee | ||
|
|
6796afdaee | ||
|
|
cc6fe57773 | ||
|
|
1dfc831938 | ||
|
|
cafeda4abf | ||
|
|
d951b99718 | ||
|
|
0ad87209e5 | ||
|
|
1b50c5404d | ||
|
|
3007f67cab | ||
|
|
ee08659f01 | ||
|
|
baf5ad0fab | ||
|
|
8bdd748aec | ||
|
|
cef0c22f52 | ||
|
|
13d3fc5cfe | ||
|
|
b91141e2be | ||
|
|
f8a4b54165 | ||
|
|
afe007ca0b | ||
|
|
8a9a044f95 | ||
|
|
5eaf03e227 | ||
|
|
a8437d9331 | ||
|
|
e0392fa98b | ||
|
|
68ff8951de | ||
|
|
9c6b31e71c | ||
|
|
50f74f5ba2 | ||
|
|
b9de2aef60 | ||
|
|
7a47598538 | ||
|
|
3c8c28ebd5 | ||
|
|
524285f767 | ||
|
|
c2a34475f1 | ||
|
|
a69195a02b | ||
|
|
19d7438499 | ||
|
|
ccb380ce06 | ||
|
|
a35c439bbd | ||
|
|
09d1f96603 | ||
|
|
26aa18d980 | ||
|
|
d10b542797 | ||
|
|
ce4e4fb8dd | ||
|
|
8f4a31cf8c | ||
|
|
23549f13d6 | ||
|
|
869d11f9a6 | ||
|
|
02e73b82ee | ||
|
|
f85f87f545 | ||
|
|
1fff5713f3 | ||
|
|
8453ec36f0 | ||
|
|
d5b3ce8424 | ||
|
|
80cbbfa5ca | ||
|
|
9177bb660f | ||
|
|
a3df39a01a | ||
|
|
25dce05cbb | ||
|
|
1542ea3e03 | ||
|
|
6084abbcfe | ||
|
|
ed19b63914 | ||
|
|
4efeb85296 | ||
|
|
fc76665615 | ||
|
|
3a044bb71a | ||
|
|
cddd606562 | ||
|
|
7a5bc51c11 | ||
|
|
9f939b4b6f | ||
|
|
80a86f5b1b | ||
|
|
a0ce1855ab | ||
|
|
a4b43b884a | ||
|
|
824c0f6667 | ||
|
|
a030fe8491 | ||
|
|
3a9429e8ef | ||
|
|
c4eb1ab748 | ||
|
|
29ed19d600 | ||
|
|
0cc65513a5 | ||
|
|
debc048659 | ||
|
|
92f5c918dd | ||
|
|
9519f1e8e2 | ||
|
|
a8f874bf05 | ||
|
|
9d9917e45b | ||
|
|
91ee0a870d | ||
|
|
6cbbffc5a9 | ||
|
|
8f26fd34d1 | ||
|
|
fda655f6d7 |
2
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
2
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -16,7 +16,7 @@ body:
|
||||
|
||||
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
|
||||
不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
|
||||
|
||||
- type: textarea
|
||||
id: plugin-info
|
||||
|
||||
2
.github/auto_assign.yml
vendored
2
.github/auto_assign.yml
vendored
@@ -11,6 +11,8 @@ reviewers:
|
||||
- Larch-C
|
||||
- anka-afk
|
||||
- advent259141
|
||||
- Fridemn
|
||||
- LIghtJUNction
|
||||
# - zouyonghe
|
||||
|
||||
# A number of reviewers added to the pull request
|
||||
|
||||
4
.github/workflows/code-format.yml
vendored
4
.github/workflows/code-format.yml
vendored
@@ -12,10 +12,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
4
.github/workflows/codeql.yml
vendored
4
.github/workflows/codeql.yml
vendored
@@ -60,7 +60,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
@@ -88,6 +88,6 @@ jobs:
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
11
.github/workflows/dashboard_ci.yml
vendored
11
.github/workflows/dashboard_ci.yml
vendored
@@ -13,11 +13,18 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 'latest'
|
||||
|
||||
- name: npm install, build
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
npm install pnpm -g
|
||||
pnpm install
|
||||
pnpm i --save-dev @types/markdown-it
|
||||
pnpm run build
|
||||
|
||||
- name: Inject Commit SHA
|
||||
id: get_sha
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -31,3 +31,5 @@ packages/python_interpreter/workplace
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
|
||||
uv.lock
|
||||
@@ -6,8 +6,20 @@ ci:
|
||||
autoupdate_schedule: weekly
|
||||
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
types_or: [ python, pyi ]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
|
||||
19
Dockerfile
19
Dockerfile
@@ -4,8 +4,6 @@ WORKDIR /AstrBot
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
nodejs \
|
||||
npm \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
@@ -13,23 +11,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libssl-dev \
|
||||
ca-certificates \
|
||||
bash \
|
||||
ffmpeg \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apt-get update && apt-get install -y curl gnupg && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||
|
||||
# 释出 ffmpeg
|
||||
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||
|
||||
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
|
||||
|
||||
|
||||
152
README.md
152
README.md
@@ -1,28 +1,38 @@
|
||||
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">文档</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
|
||||
|
||||
## 主要功能
|
||||
|
||||
@@ -34,7 +44,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
|
||||
|
||||
## 部署方式
|
||||
|
||||
#### Docker 部署
|
||||
#### Docker 部署(推荐 🥳)
|
||||
|
||||
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
@@ -62,7 +72,7 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
@@ -100,7 +110,6 @@ uv run main.py
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:975206796
|
||||
- 开发者群(备份):295657329
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
@@ -110,49 +119,83 @@ uv run main.py
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
**官方维护**
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(官方平台) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 企微应用 | ✔ |
|
||||
| 企微智能机器人 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
| 钉钉 | ✔ |
|
||||
| Slack | ✔ |
|
||||
| Discord | ✔ |
|
||||
| Satori | ✔ |
|
||||
| Misskey | ✔ |
|
||||
| Whatsapp | 将支持 |
|
||||
| LINE | 将支持 |
|
||||
|
||||
**社区维护**
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
|
||||
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
|
||||
|
||||
## ⚡ 提供商支持情况
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | 文本生成 | |
|
||||
| Google Gemini | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||
**大模型服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | |
|
||||
| Google Gemini | ✔ | |
|
||||
| Moonshot AI | ✔ | |
|
||||
| 智谱 AI | ✔ | |
|
||||
| DeepSeek | ✔ | |
|
||||
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
|
||||
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
|
||||
| 硅基流动 | ✔ | |
|
||||
| PPIO 派欧云 | ✔ | |
|
||||
| ModelScope | ✔ | |
|
||||
| OneAPI | ✔ | |
|
||||
| Dify | ✔ | |
|
||||
| 阿里云百炼应用 | ✔ | |
|
||||
| Coze | ✔ | |
|
||||
|
||||
**语音转文本服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| Whisper | ✔ | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 本地部署 |
|
||||
|
||||
**文本转语音服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI TTS | ✔ | |
|
||||
| Gemini TTS | ✔ | |
|
||||
| GSVI | ✔ | GPT-Sovits-Inference |
|
||||
| GPT-SoVITs | ✔ | GPT-Sovits |
|
||||
| FishAudio | ✔ | |
|
||||
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | |
|
||||
| Azure TTS | ✔ | |
|
||||
| Minimax TTS | ✔ | |
|
||||
| 火山引擎 TTS | ✔ | |
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -167,12 +210,11 @@ uv run main.py
|
||||
AstrBot 使用 `ruff` 进行代码格式化和检查。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Soulter/AstrBot
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
@@ -185,29 +227,17 @@ pre-commit install
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||
|
||||
另外,一些同类型其他的活跃开源 Bot 项目:
|
||||
|
||||
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
|
||||
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
||||
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
||||
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
||||
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
||||
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
18
README_en.md
18
README_en.md
@@ -10,16 +10,16 @@ _✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
[](https://codecov.io/gh/AstrBotDevs/AstrBot)
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
|
||||
</div>
|
||||
|
||||
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
|
||||
@@ -49,7 +49,7 @@ Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app
|
||||
|
||||
#### Replit Deployment
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### CasaOS Deployment
|
||||
|
||||
@@ -67,8 +67,8 @@ See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
|
||||
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
|
||||
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
|
||||
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
|
||||
| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
|
||||
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
|
||||
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
|
||||
| Feishu | ✔ | Group chats | Text, Images |
|
||||
| WeChat Open Platform | 🚧 | Planned | - |
|
||||
| Discord | 🚧 | Planned | - |
|
||||
@@ -157,7 +157,7 @@ _✨ Built-in Web Chat Interface ✨_
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
[](https://star-history.com/#AstrBotDevs/AstrBot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
@@ -169,7 +169,7 @@ _✨ Built-in Web Chat Interface ✨_
|
||||
|
||||
<!-- ## ✨ ATRI [Beta]
|
||||
|
||||
Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||
Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)
|
||||
|
||||
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
|
||||
2. Long-term memory
|
||||
|
||||
@@ -10,16 +10,16 @@ _✨ 簡単に使えるマルチプラットフォーム LLM チャットボッ
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
[](https://codecov.io/gh/AstrBotDevs/AstrBot)
|
||||
|
||||
<a href="https://astrbot.app/">ドキュメントを見る</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
|
||||
</div>
|
||||
|
||||
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
|
||||
@@ -50,7 +50,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
|
||||
#### Replit デプロイ
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### CasaOS デプロイ
|
||||
|
||||
|
||||
0
astrbot.lock
Normal file
0
astrbot.lock
Normal file
@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
|
||||
class Agent(Generic[TContext]):
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
tools: list[str, FunctionTool] | None = None
|
||||
tools: list[str | FunctionTool] | None = None
|
||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||
|
||||
@@ -40,8 +40,15 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
if transport_type == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
@@ -92,7 +99,7 @@ class MCPClient:
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
self.name = None
|
||||
self.name: str | None = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
@@ -121,7 +128,14 @@ class MCPClient:
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||
|
||||
if transport_type != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
@@ -134,7 +148,7 @@ class MCPClient:
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
@@ -159,7 +173,7 @@ class MCPClient:
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
@@ -198,6 +212,8 @@ class MCPClient:
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
if not self.session:
|
||||
raise Exception("MCP Client is not initialized")
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
@@ -198,9 +198,49 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
if not func_tool:
|
||||
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: 未找到工具 {func_tool_name}",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
valid_params = {} # 参数过滤:只传递函数实际需要的参数
|
||||
|
||||
# 获取实际的 handler 函数
|
||||
if func_tool.handler:
|
||||
logger.debug(
|
||||
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
|
||||
)
|
||||
if func_tool.parameters and func_tool.parameters.get("properties"):
|
||||
expected_params = set(func_tool.parameters["properties"].keys())
|
||||
|
||||
valid_params = {
|
||||
k: v
|
||||
for k, v in func_tool_args.items()
|
||||
if k in expected_params
|
||||
}
|
||||
|
||||
# 记录被忽略的参数
|
||||
ignored_params = set(func_tool_args.keys()) - set(
|
||||
valid_params.keys()
|
||||
)
|
||||
if ignored_params:
|
||||
logger.warning(
|
||||
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
|
||||
)
|
||||
else:
|
||||
# 如果没有 handler(如 MCP 工具),使用所有参数
|
||||
valid_params = func_tool_args
|
||||
logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
self.run_context, func_tool, valid_params
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
|
||||
@@ -208,11 +248,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
executor = self.tool_executor.execute(
|
||||
tool=func_tool,
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
**valid_params, # 只传递有效的参数
|
||||
)
|
||||
async for resp in executor:
|
||||
|
||||
_final_resp: CallToolResult | None = None
|
||||
async for resp in executor: # type: ignore
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
_final_resp = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -258,7 +301,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
).base64_image(res.content[0].data)
|
||||
).base64_image(resource.blob)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -269,17 +312,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context,
|
||||
func_tool_name,
|
||||
func_tool_args,
|
||||
resp,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
@@ -289,27 +321,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool, func_tool_args, _final_resp
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||
|
||||
self.run_context.event.clear_result()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
from typing import Awaitable, Literal, Any, Optional
|
||||
from typing import Awaitable, Callable, Literal, Any, Optional
|
||||
from .mcp_client import MCPClient
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
|
||||
class FunctionTool:
|
||||
"""A class representing a function tool that can be used in function calling."""
|
||||
|
||||
name: str | None = None
|
||||
name: str
|
||||
parameters: dict | None = None
|
||||
description: str | None = None
|
||||
handler: Awaitable | None = None
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str | None = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
@@ -51,7 +51,7 @@ class ToolSet:
|
||||
This class provides methods to add, remove, and retrieve tools, as well as
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||
|
||||
def __init__(self, tools: list[FunctionTool] = None):
|
||||
def __init__(self, tools: list[FunctionTool] | None = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
@@ -79,7 +79,13 @@ class ToolSet:
|
||||
return None
|
||||
|
||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
):
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -104,7 +110,7 @@ class ToolSet:
|
||||
self.remove_tool(name)
|
||||
|
||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||
def get_func(self, name: str) -> list[FunctionTool]:
|
||||
def get_func(self, name: str) -> FunctionTool | None:
|
||||
"""Get all function tools."""
|
||||
return self.get_tool(name)
|
||||
|
||||
@@ -125,7 +131,11 @@ class ToolSet:
|
||||
},
|
||||
}
|
||||
|
||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
||||
if (
|
||||
tool.parameters
|
||||
and tool.parameters.get("properties")
|
||||
or not omit_empty_parameter_field
|
||||
):
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
@@ -135,14 +145,14 @@ class ToolSet:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
input_schema = {"type": "object"}
|
||||
if tool.parameters:
|
||||
input_schema["properties"] = tool.parameters.get("properties", {})
|
||||
input_schema["required"] = tool.parameters.get("required", [])
|
||||
tool_def = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": tool.parameters.get("properties", {}),
|
||||
"required": tool.parameters.get("required", []),
|
||||
},
|
||||
"input_schema": input_schema,
|
||||
}
|
||||
result.append(tool_def)
|
||||
return result
|
||||
@@ -210,14 +220,15 @@ class ToolSet:
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
tools = []
|
||||
for tool in self.tools:
|
||||
d = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": convert_schema(tool.parameters),
|
||||
}
|
||||
for tool in self.tools
|
||||
]
|
||||
if tool.parameters:
|
||||
d["parameters"] = convert_schema(tool.parameters)
|
||||
tools.append(d)
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
|
||||
@@ -9,3 +9,4 @@ class AstrAgentContext:
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
|
||||
from typing import TypeVar, TypedDict
|
||||
|
||||
@@ -15,14 +16,12 @@ class ConfInfo(TypedDict):
|
||||
"""Configuration information for a specific session or platform."""
|
||||
|
||||
id: str # UUID of the configuration or "default"
|
||||
umop: list[str] # Unified Message Origin Pattern
|
||||
name: str
|
||||
path: str # File name to the configuration file
|
||||
|
||||
|
||||
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
id="default",
|
||||
umop=["::"],
|
||||
name="default",
|
||||
path=ASTRBOT_CONFIG_PATH,
|
||||
)
|
||||
@@ -31,8 +30,14 @@ DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
class AstrBotConfigManager:
|
||||
"""A class to manage the system configuration of AstrBot, aka ACM"""
|
||||
|
||||
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
|
||||
def __init__(
|
||||
self,
|
||||
default_config: AstrBotConfig,
|
||||
ucr: UmopConfigRouter,
|
||||
sp: SharedPreferences,
|
||||
):
|
||||
self.sp = sp
|
||||
self.ucr = ucr
|
||||
self.confs: dict[str, AstrBotConfig] = {}
|
||||
"""uuid / "default" -> AstrBotConfig"""
|
||||
self.confs["default"] = default_config
|
||||
@@ -63,24 +68,15 @@ class AstrBotConfigManager:
|
||||
)
|
||||
continue
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
|
||||
|
||||
Returns:
|
||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||
"""
|
||||
# uuid -> { "umop": list, "path": str, "name": str }
|
||||
# uuid -> { "path": str, "name": str }
|
||||
abconf_data = self._get_abconf_data()
|
||||
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
@@ -89,10 +85,13 @@ class AstrBotConfigManager:
|
||||
except Exception:
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
for uuid_, meta in abconf_data.items():
|
||||
for pattern in meta["umop"]:
|
||||
if self._is_umo_match(pattern, umo):
|
||||
return ConfInfo(**meta, id=uuid_)
|
||||
conf_id = self.ucr.get_conf_id_for_umop(umo)
|
||||
if conf_id:
|
||||
meta = abconf_data.get(conf_id)
|
||||
if meta and isinstance(meta, dict):
|
||||
# the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
|
||||
meta.pop("umop", None)
|
||||
return ConfInfo(**meta, id=conf_id)
|
||||
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
@@ -100,23 +99,14 @@ class AstrBotConfigManager:
|
||||
self,
|
||||
abconf_path: str,
|
||||
abconf_id: str,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
abconf_name: str | None = None,
|
||||
) -> None:
|
||||
"""保存配置文件的映射关系"""
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
random_word = abconf_name or uuid.uuid4().hex[:8]
|
||||
abconf_data[abconf_id] = {
|
||||
"umop": umo_parts,
|
||||
"path": abconf_path,
|
||||
"name": random_word,
|
||||
}
|
||||
@@ -153,29 +143,26 @@ class AstrBotConfigManager:
|
||||
def get_conf_list(self) -> list[ConfInfo]:
|
||||
"""获取所有配置文件的元数据列表"""
|
||||
conf_list = []
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
abconf_mapping = self._get_abconf_data()
|
||||
for uuid_, meta in abconf_mapping.items():
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
meta.pop("umop", None)
|
||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
return conf_list
|
||||
|
||||
def create_conf(
|
||||
self,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
config: dict = DEFAULT_CONFIG,
|
||||
name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
|
||||
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
||||
"""
|
||||
conf_uuid = str(uuid.uuid4())
|
||||
conf_file_name = f"abconf_{conf_uuid}.json"
|
||||
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
|
||||
conf = AstrBotConfig(config_path=conf_path, default_config=config)
|
||||
conf.save_config()
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name)
|
||||
self.confs[conf_uuid] = conf
|
||||
return conf_uuid
|
||||
|
||||
@@ -228,15 +215,12 @@ class AstrBotConfigManager:
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
|
||||
) -> bool:
|
||||
def update_conf_info(self, conf_id: str, name: str | None = None) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
name: 新的配置文件名称 (可选)
|
||||
umo_parts: 新的 UMO 部分列表 (可选)
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
@@ -255,18 +239,6 @@ class AstrBotConfigManager:
|
||||
if name is not None:
|
||||
abconf_data[conf_id]["name"] = name
|
||||
|
||||
# 更新 UMO 部分
|
||||
if umo_parts is not None:
|
||||
# 验证 UMO 部分格式
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
self.abconf_data = abconf_data
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.1.2"
|
||||
VERSION = "4.5.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -57,19 +57,22 @@ DEFAULT_CONFIG = {
|
||||
"web_search": False,
|
||||
"websearch_provider": "default",
|
||||
"websearch_tavily_key": [],
|
||||
"websearch_baidu_app_builder_key": "",
|
||||
"web_search_link": False,
|
||||
"display_reasoning_text": False,
|
||||
"identifier": False,
|
||||
"group_name_display": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "",
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -115,6 +118,15 @@ DEFAULT_CONFIG = {
|
||||
"port": 6185,
|
||||
},
|
||||
"platform": [],
|
||||
"platform_specific": {
|
||||
# 平台特异配置:按平台分类,平台下按功能分组
|
||||
"lark": {
|
||||
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
|
||||
},
|
||||
"telegram": {
|
||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||
},
|
||||
},
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"pip_install_arg": "",
|
||||
@@ -122,8 +134,11 @@ DEFAULT_CONFIG = {
|
||||
"persona": [], # deprecated
|
||||
"timezone": "Asia/Shanghai",
|
||||
"callback_api_base": "",
|
||||
"default_kb_collection": "", # 默认知识库名称
|
||||
"default_kb_collection": "", # 默认知识库名称, 已经过时
|
||||
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
|
||||
"kb_names": [], # 默认知识库名称列表
|
||||
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
||||
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
||||
}
|
||||
|
||||
|
||||
@@ -150,10 +165,11 @@ CONFIG_METADATA_2 = {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"is_sandbox": False,
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"QQ 个人号(aiocqhttp)": {
|
||||
"QQ 个人号(OneBot v11)": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -161,7 +177,7 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"微信个人号(WeChatPadPro)": {
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
@@ -197,6 +213,18 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
},
|
||||
"企业微信智能机器人": {
|
||||
"id": "wecom_ai_bot",
|
||||
"type": "wecom_ai_bot",
|
||||
"enable": True,
|
||||
"wecomaibot_init_respond_text": "💭 思考中...",
|
||||
"wecomaibot_friend_message_welcome_text": "",
|
||||
"wecom_ai_bot_name": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6198,
|
||||
},
|
||||
"飞书(Lark)": {
|
||||
"id": "lark",
|
||||
"type": "lark",
|
||||
@@ -235,6 +263,24 @@ CONFIG_METADATA_2 = {
|
||||
"discord_guild_id_for_debug": "",
|
||||
"discord_activity_name": "",
|
||||
},
|
||||
"Misskey": {
|
||||
"id": "misskey",
|
||||
"type": "misskey",
|
||||
"enable": False,
|
||||
"misskey_instance_url": "https://misskey.example",
|
||||
"misskey_token": "",
|
||||
"misskey_default_visibility": "public",
|
||||
"misskey_local_only": False,
|
||||
"misskey_enable_chat": True,
|
||||
# download / security options
|
||||
"misskey_allow_insecure_downloads": False,
|
||||
"misskey_download_timeout": 15,
|
||||
"misskey_download_chunk_size": 65536,
|
||||
"misskey_max_download_bytes": None,
|
||||
"misskey_enable_file_upload": True,
|
||||
"misskey_upload_concurrency": 3,
|
||||
"misskey_upload_folder": "",
|
||||
},
|
||||
"Slack": {
|
||||
"id": "slack",
|
||||
"type": "slack",
|
||||
@@ -252,43 +298,61 @@ CONFIG_METADATA_2 = {
|
||||
"type": "satori",
|
||||
"enable": False,
|
||||
"satori_api_base_url": "http://localhost:5140/satori/v1",
|
||||
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
|
||||
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
|
||||
"satori_token": "",
|
||||
"satori_auto_reconnect": True,
|
||||
"satori_heartbeat_interval": 10,
|
||||
"satori_reconnect_delay": 5,
|
||||
},
|
||||
# "WebChat": {
|
||||
# "id": "webchat",
|
||||
# "type": "webchat",
|
||||
# "enable": False,
|
||||
# "webchat_link_path": "",
|
||||
# "webchat_present_type": "fullscreen",
|
||||
# },
|
||||
},
|
||||
"items": {
|
||||
# "webchat_link_path": {
|
||||
# "description": "链接路径",
|
||||
# "_special": "webchat_link_path",
|
||||
# "type": "string",
|
||||
# },
|
||||
# "webchat_present_type": {
|
||||
# "_special": "webchat_present_type",
|
||||
# "description": "展现形式",
|
||||
# "type": "string",
|
||||
# "options": ["fullscreen", "embedded"],
|
||||
# },
|
||||
"satori_api_base_url": {
|
||||
"description": "Satori API Base URL",
|
||||
"description": "Satori API 终结点",
|
||||
"type": "string",
|
||||
"hint": "The base URL for the Satori API.",
|
||||
"hint": "Satori API 的基础地址。",
|
||||
},
|
||||
"satori_endpoint": {
|
||||
"description": "Satori WebSocket Endpoint",
|
||||
"description": "Satori WebSocket 终结点",
|
||||
"type": "string",
|
||||
"hint": "The WebSocket endpoint for Satori events.",
|
||||
"hint": "Satori 事件的 WebSocket 端点。",
|
||||
},
|
||||
"satori_token": {
|
||||
"description": "Satori Token",
|
||||
"description": "Satori 令牌",
|
||||
"type": "string",
|
||||
"hint": "The token used for authenticating with the Satori API.",
|
||||
"hint": "用于 Satori API 身份验证的令牌。",
|
||||
},
|
||||
"satori_auto_reconnect": {
|
||||
"description": "Enable Auto Reconnect",
|
||||
"description": "启用自动重连",
|
||||
"type": "bool",
|
||||
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
|
||||
"hint": "断开连接时是否自动重新连接 WebSocket。",
|
||||
},
|
||||
"satori_heartbeat_interval": {
|
||||
"description": "Satori Heartbeat Interval",
|
||||
"description": "Satori 心跳间隔",
|
||||
"type": "int",
|
||||
"hint": "The interval (in seconds) for sending heartbeat messages.",
|
||||
"hint": "发送心跳消息的间隔(秒)。",
|
||||
},
|
||||
"satori_reconnect_delay": {
|
||||
"description": "Satori Reconnect Delay",
|
||||
"description": "Satori 重连延迟",
|
||||
"type": "int",
|
||||
"hint": "The delay (in seconds) before attempting to reconnect.",
|
||||
"hint": "尝试重新连接前的延迟时间(秒)。",
|
||||
},
|
||||
"slack_connection_mode": {
|
||||
"description": "Slack Connection Mode",
|
||||
@@ -336,6 +400,67 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||
},
|
||||
"misskey_instance_url": {
|
||||
"description": "Misskey 实例 URL",
|
||||
"type": "string",
|
||||
"hint": "例如 https://misskey.example,填写 Bot 账号所在的 Misskey 实例地址",
|
||||
},
|
||||
"misskey_token": {
|
||||
"description": "Misskey Access Token",
|
||||
"type": "string",
|
||||
"hint": "连接服务设置生成的 API 鉴权访问令牌(Access token)",
|
||||
},
|
||||
"misskey_default_visibility": {
|
||||
"description": "默认帖子可见性",
|
||||
"type": "string",
|
||||
"options": ["public", "home", "followers"],
|
||||
"hint": "机器人发帖时的默认可见性设置。public:公开,home:主页时间线,followers:仅关注者。",
|
||||
},
|
||||
"misskey_local_only": {
|
||||
"description": "仅限本站(不参与联合)",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
|
||||
},
|
||||
"misskey_enable_chat": {
|
||||
"description": "启用聊天消息响应",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人将会监听和响应私信聊天消息",
|
||||
},
|
||||
"misskey_enable_file_upload": {
|
||||
"description": "启用文件上传到 Misskey",
|
||||
"type": "bool",
|
||||
"hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。",
|
||||
},
|
||||
"misskey_allow_insecure_downloads": {
|
||||
"description": "允许不安全下载(禁用 SSL 验证)",
|
||||
"type": "bool",
|
||||
"hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。",
|
||||
},
|
||||
"misskey_download_timeout": {
|
||||
"description": "远端下载超时时间(秒)",
|
||||
"type": "int",
|
||||
"hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。",
|
||||
},
|
||||
"misskey_download_chunk_size": {
|
||||
"description": "流式下载分块大小(字节)",
|
||||
"type": "int",
|
||||
"hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。",
|
||||
},
|
||||
"misskey_max_download_bytes": {
|
||||
"description": "最大允许下载字节数(超出则中止)",
|
||||
"type": "int",
|
||||
"hint": "如果希望限制下载文件的最大大小以防止 OOM,请填写最大字节数;留空或 null 表示不限制。",
|
||||
},
|
||||
"misskey_upload_concurrency": {
|
||||
"description": "并发上传限制",
|
||||
"type": "int",
|
||||
"hint": "同时进行的文件上传任务上限(整数,默认 3)。",
|
||||
},
|
||||
"misskey_upload_folder": {
|
||||
"description": "上传到网盘的目标文件夹 ID",
|
||||
"type": "string",
|
||||
"hint": "可选:填写 Misskey 网盘中目标文件夹的 ID,上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。",
|
||||
},
|
||||
"telegram_command_register": {
|
||||
"description": "Telegram 命令注册",
|
||||
"type": "bool",
|
||||
@@ -387,24 +512,38 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,机器人可以接收到频道的私聊消息。",
|
||||
},
|
||||
"ws_reverse_host": {
|
||||
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
|
||||
"description": "反向 Websocket 主机",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
|
||||
"hint": "AstrBot 将作为服务器端。",
|
||||
},
|
||||
"ws_reverse_port": {
|
||||
"description": "反向 Websocket 端口",
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"ws_reverse_token": {
|
||||
"description": "反向 Websocket Token",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
"hint": "反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
},
|
||||
"wecom_ai_bot_name": {
|
||||
"description": "企业微信智能机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填写正确,否则无法使用一些指令。",
|
||||
},
|
||||
"wecomaibot_init_respond_text": {
|
||||
"description": "企业微信智能机器人初始响应文本",
|
||||
"type": "string",
|
||||
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
|
||||
},
|
||||
"wecomaibot_friend_message_welcome_text": {
|
||||
"description": "企业微信智能机器人私聊欢迎语",
|
||||
"type": "string",
|
||||
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
},
|
||||
"discord_token": {
|
||||
"description": "Discord Bot Token",
|
||||
@@ -729,7 +868,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -775,6 +914,21 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
@@ -832,6 +986,18 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
"provider": "coze",
|
||||
"provider_type": "chat_completion",
|
||||
"type": "coze",
|
||||
"enable": True,
|
||||
"coze_api_key": "",
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
@@ -983,6 +1149,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
|
||||
"id": "dashscope_tts",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope_tts",
|
||||
@@ -1250,6 +1417,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "嵌入维度",
|
||||
"type": "int",
|
||||
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
|
||||
"_special": "get_embedding_dim",
|
||||
},
|
||||
"embedding_model": {
|
||||
"description": "嵌入模型",
|
||||
@@ -1362,11 +1530,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "服务订阅密钥",
|
||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
||||
},
|
||||
"dashscope_tts_voice": {
|
||||
"description": "语音合成模型",
|
||||
"type": "string",
|
||||
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
|
||||
},
|
||||
"dashscope_tts_voice": {"description": "音色", "type": "string"},
|
||||
"gm_resp_image_modal": {
|
||||
"description": "启用图片模态",
|
||||
"type": "bool",
|
||||
@@ -1698,6 +1862,26 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||
"obvious": True,
|
||||
},
|
||||
"coze_api_key": {
|
||||
"description": "Coze API Key",
|
||||
"type": "string",
|
||||
"hint": "Coze API 密钥,用于访问 Coze 服务。",
|
||||
},
|
||||
"bot_id": {
|
||||
"description": "Bot ID",
|
||||
"type": "string",
|
||||
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
|
||||
},
|
||||
"coze_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
|
||||
},
|
||||
"auto_save_history": {
|
||||
"description": "由 Coze 管理对话记录",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
@@ -1724,6 +1908,9 @@ CONFIG_METADATA_2 = {
|
||||
"identifier": {
|
||||
"type": "bool",
|
||||
},
|
||||
"group_name_display": {
|
||||
"type": "bool",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"type": "bool",
|
||||
},
|
||||
@@ -1752,6 +1939,10 @@ CONFIG_METADATA_2 = {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
"tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -1874,6 +2065,9 @@ CONFIG_METADATA_2 = {
|
||||
"default_kb_collection": {
|
||||
"type": "string",
|
||||
},
|
||||
"kb_names": {"type": "list", "items": {"type": "string"}},
|
||||
"kb_fusion_top_k": {"type": "int", "default": 20},
|
||||
"kb_final_top_k": {"type": "int", "default": 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1903,17 +2097,33 @@ CONFIG_METADATA_3 = {
|
||||
"_special": "select_provider",
|
||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||
},
|
||||
"provider_stt_settings.enable": {
|
||||
"description": "启用语音转文本",
|
||||
"type": "bool",
|
||||
"hint": "STT 总开关。",
|
||||
},
|
||||
"provider_stt_settings.provider_id": {
|
||||
"description": "语音转文本模型",
|
||||
"description": "默认语音转文本模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型。",
|
||||
"_special": "select_provider_stt",
|
||||
"condition": {
|
||||
"provider_stt_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.enable": {
|
||||
"description": "启用文本转语音",
|
||||
"type": "bool",
|
||||
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
|
||||
},
|
||||
"provider_tts_settings.provider_id": {
|
||||
"description": "文本转语音模型",
|
||||
"description": "默认文本转语音模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
|
||||
"_special": "select_provider_tts",
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.image_caption_prompt": {
|
||||
"description": "图片转述提示词",
|
||||
@@ -1936,10 +2146,22 @@ CONFIG_METADATA_3 = {
|
||||
"description": "知识库",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"default_kb_collection": {
|
||||
"description": "默认使用的知识库",
|
||||
"type": "string",
|
||||
"kb_names": {
|
||||
"description": "知识库列表",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"_special": "select_knowledgebase",
|
||||
"hint": "支持多选",
|
||||
},
|
||||
"kb_fusion_top_k": {
|
||||
"description": "融合检索结果数",
|
||||
"type": "int",
|
||||
"hint": "多个知识库检索结果融合后的返回结果数量",
|
||||
},
|
||||
"kb_final_top_k": {
|
||||
"description": "最终返回结果数",
|
||||
"type": "int",
|
||||
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1954,7 +2176,7 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.websearch_provider": {
|
||||
"description": "网页搜索提供商",
|
||||
"type": "string",
|
||||
"options": ["default", "tavily"],
|
||||
"options": ["default", "tavily", "baidu_ai_search"],
|
||||
},
|
||||
"provider_settings.websearch_tavily_key": {
|
||||
"description": "Tavily API Key",
|
||||
@@ -1965,6 +2187,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.websearch_provider": "tavily",
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_baidu_app_builder_key": {
|
||||
"description": "百度千帆智能云 APP Builder API Key",
|
||||
"type": "string",
|
||||
"hint": "参考:https://console.bce.baidu.com/iam/#/iam/apikey/list",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "baidu_ai_search",
|
||||
},
|
||||
},
|
||||
"provider_settings.web_search_link": {
|
||||
"description": "显示来源引用",
|
||||
"type": "bool",
|
||||
@@ -1983,6 +2213,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "用户识别",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.group_name_display": {
|
||||
"description": "显示群名称",
|
||||
"type": "bool",
|
||||
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
||||
},
|
||||
"provider_settings.datetime_system_prompt": {
|
||||
"description": "现实世界时间感知",
|
||||
"type": "bool",
|
||||
@@ -1995,6 +2230,10 @@ CONFIG_METADATA_3 = {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
"provider_settings.tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
},
|
||||
"provider_settings.streaming_response": {
|
||||
"description": "流式回复",
|
||||
"type": "bool",
|
||||
@@ -2016,12 +2255,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
},
|
||||
"provider_settings.prompt_prefix": {
|
||||
"description": "额外前缀提示词",
|
||||
"description": "用户提示词",
|
||||
"type": "string",
|
||||
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
|
||||
},
|
||||
"provider_settings.dual_output": {
|
||||
"provider_tts_settings.dual_output": {
|
||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||
"type": "bool",
|
||||
},
|
||||
@@ -2202,6 +2443,32 @@ CONFIG_METADATA_3 = {
|
||||
"description": "用户权限不足时是否回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.lark.pre_ack_emoji.enable": {
|
||||
"description": "[飞书] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.lark.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(飞书表情枚举名)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "表情枚举名参考:https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
|
||||
"condition": {
|
||||
"platform_specific.lark.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": {
|
||||
"description": "[Telegram] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.telegram.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(Unicode)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "Telegram 仅支持固定反应集合,参考:https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
|
||||
"condition": {
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
||||
|
||||
import json
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Callable, Awaitable
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation, ConversationV2
|
||||
|
||||
@@ -20,6 +20,38 @@ class ConversationManager:
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
|
||||
# 会话删除回调函数列表(用于级联清理,如知识库配置)
|
||||
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
|
||||
|
||||
def register_on_session_deleted(
|
||||
self, callback: Callable[[str], Awaitable[None]]
|
||||
) -> None:
|
||||
"""注册会话删除回调函数
|
||||
|
||||
其他模块可以注册回调来响应会话删除事件,实现级联清理。
|
||||
例如:知识库模块可以注册回调来清理会话的知识库配置。
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数
|
||||
"""
|
||||
self._on_session_deleted_callbacks.append(callback)
|
||||
|
||||
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
|
||||
"""触发会话删除回调
|
||||
|
||||
Args:
|
||||
unified_msg_origin: 会话ID
|
||||
"""
|
||||
for callback in self._on_session_deleted_callbacks:
|
||||
try:
|
||||
await callback(unified_msg_origin)
|
||||
except Exception as e:
|
||||
from astrbot.core import logger
|
||||
|
||||
logger.error(
|
||||
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
|
||||
)
|
||||
|
||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||
created_at = int(conv_v2.created_at.timestamp())
|
||||
@@ -87,17 +119,28 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
f = False
|
||||
if not conversation_id:
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
f = True
|
||||
if conversation_id:
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if curr_cid == conversation_id:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
||||
"""删除会话的所有对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
"""
|
||||
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
# 触发会话删除回调(级联清理)
|
||||
await self._trigger_session_deleted(unified_msg_origin)
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config, html_renderer
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -26,14 +25,17 @@ from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
@@ -84,11 +86,21 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
await html_renderer.initialize()
|
||||
|
||||
# 初始化 UMOP 配置路由器
|
||||
self.umop_config_router = UmopConfigRouter(sp=sp)
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
default_config=self.astrbot_config, sp=sp
|
||||
default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
|
||||
)
|
||||
|
||||
# 4.5 to 4.6 migration for umop_config_router
|
||||
try:
|
||||
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
@@ -110,6 +122,9 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化平台消息历史管理器
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
@@ -121,6 +136,7 @@ class AstrBotCoreLifecycle:
|
||||
self.platform_message_history_manager,
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
self.kb_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -132,8 +148,9 @@ class AstrBotCoreLifecycle:
|
||||
# 根据配置实例化各个 Provider
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
# 初始化更新器
|
||||
@@ -148,7 +165,7 @@ class AstrBotCoreLifecycle:
|
||||
self.start_time = int(time.time())
|
||||
|
||||
# 初始化当前任务列表
|
||||
self.curr_tasks: List[asyncio.Task] = []
|
||||
self.curr_tasks: list[asyncio.Task] = []
|
||||
|
||||
# 根据配置实例化各个平台适配器
|
||||
await self.platform_manager.initialize()
|
||||
@@ -233,6 +250,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
|
||||
# 再次遍历curr_tasks等待每个任务真正结束
|
||||
@@ -248,12 +266,13 @@ class AstrBotCoreLifecycle:
|
||||
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
threading.Thread(
|
||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||
).start()
|
||||
|
||||
def load_platform(self) -> List[asyncio.Task]:
|
||||
def load_platform(self) -> list[asyncio.Task]:
|
||||
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||
tasks = []
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
|
||||
@@ -154,12 +154,17 @@ class BaseDatabase(abc.ABC):
|
||||
"""Delete a conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
"""Delete all conversations for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
content: dict,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
@@ -282,3 +287,14 @@ class BaseDatabase(abc.ABC):
|
||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||
# """Get all LLM messages for a specific conversation."""
|
||||
# ...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
search_query: str | None = None,
|
||||
platform: str | None = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
|
||||
...
|
||||
|
||||
44
astrbot/core/db/migration/migra_45_to_46.py
Normal file
44
astrbot/core/db/migration/migra_45_to_46.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
|
||||
|
||||
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
|
||||
abconf_data = acm.abconf_data
|
||||
|
||||
if not isinstance(abconf_data, dict):
|
||||
# should be unreachable
|
||||
logger.warning(
|
||||
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
|
||||
)
|
||||
return
|
||||
|
||||
# 如果任何一项带有 umop,则说明需要迁移
|
||||
need_migration = False
|
||||
for conf_id, conf_info in abconf_data.items():
|
||||
if isinstance(conf_info, dict) and "umop" in conf_info:
|
||||
need_migration = True
|
||||
break
|
||||
|
||||
if not need_migration:
|
||||
return
|
||||
|
||||
logger.info("Starting migration from version 4.5 to 4.6")
|
||||
|
||||
# extract umo->conf_id mapping
|
||||
umo_to_conf_id = {}
|
||||
for conf_id, conf_info in abconf_data.items():
|
||||
if isinstance(conf_info, dict) and "umop" in conf_info:
|
||||
umop_ls = conf_info.pop("umop")
|
||||
if not isinstance(umop_ls, list):
|
||||
continue
|
||||
for umo in umop_ls:
|
||||
if isinstance(umo, str) and umo not in umo_to_conf_id:
|
||||
umo_to_conf_id[umo] = conf_id
|
||||
|
||||
# update the abconf data
|
||||
await sp.global_put("abconf_mapping", abconf_data)
|
||||
# update the umop config router
|
||||
await ucr.update_routing_data(umo_to_conf_id)
|
||||
|
||||
logger.info("Migration from version 45 to 46 completed successfully")
|
||||
@@ -75,7 +75,9 @@ class Persona(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "personas"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
persona_id: str = Field(max_length=255, nullable=False)
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
@@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
@@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
|
||||
@@ -15,9 +15,8 @@ from astrbot.core.db.po import (
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
from sqlalchemy import select, update, delete, text
|
||||
from sqlmodel import select, update, delete, text, func, or_, desc, col
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
@@ -33,6 +32,12 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Initialize the database by creating tables if they do not exist."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
# ====
|
||||
@@ -41,10 +46,10 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime = None,
|
||||
platform_id,
|
||||
platform_type,
|
||||
count=1,
|
||||
timestamp=None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
async with self.get_db() as session:
|
||||
@@ -75,7 +80,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
|
||||
select(func.count(col(PlatformStat.platform_id))).select_from(
|
||||
PlatformStat
|
||||
)
|
||||
)
|
||||
count = result.scalar_one_or_none()
|
||||
return count if count is not None else 0
|
||||
@@ -95,7 +102,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
return result.scalars().all()
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ====
|
||||
# Conversation Management
|
||||
@@ -111,7 +118,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
if platform_id:
|
||||
query = query.where(ConversationV2.platform_id == platform_id)
|
||||
# order by
|
||||
query = query.order_by(ConversationV2.created_at.desc())
|
||||
query = query.order_by(desc(ConversationV2.created_at))
|
||||
result = await session.execute(query)
|
||||
|
||||
return result.scalars().all()
|
||||
@@ -129,7 +136,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
offset = (page - 1) * page_size
|
||||
result = await session.execute(
|
||||
select(ConversationV2)
|
||||
.order_by(ConversationV2.created_at.desc())
|
||||
.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
@@ -150,11 +157,26 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
if platform_ids:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.platform_id.in_(platform_ids)
|
||||
col(ConversationV2.platform_id).in_(platform_ids)
|
||||
)
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
base_query = base_query.where(
|
||||
ConversationV2.title.ilike(f"%{search_query}%")
|
||||
or_(
|
||||
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
||||
)
|
||||
)
|
||||
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||
for msg_type in kwargs["message_types"]:
|
||||
base_query = base_query.where(
|
||||
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
|
||||
)
|
||||
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||
base_query = base_query.where(
|
||||
col(ConversationV2.platform_id).in_(kwargs["platforms"])
|
||||
)
|
||||
|
||||
# Get total count matching the filters
|
||||
@@ -165,7 +187,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
result_query = (
|
||||
base_query.order_by(ConversationV2.created_at.desc())
|
||||
base_query.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
@@ -211,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(ConversationV2).where(
|
||||
ConversationV2.conversation_id == cid
|
||||
col(ConversationV2.conversation_id) == cid
|
||||
)
|
||||
values = {}
|
||||
if title is not None:
|
||||
@@ -231,9 +253,126 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.conversation_id) == cid
|
||||
)
|
||||
)
|
||||
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
|
||||
)
|
||||
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page=1,
|
||||
page_size=20,
|
||||
search_query=None,
|
||||
platform=None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
base_query = (
|
||||
select(
|
||||
col(Preference.scope_id).label("session_id"),
|
||||
func.json_extract(Preference.value, "$.val").label(
|
||||
"conversation_id"
|
||||
), # type: ignore
|
||||
col(ConversationV2.persona_id).label("persona_id"),
|
||||
col(ConversationV2.title).label("title"),
|
||||
col(Persona.persona_id).label("persona_name"),
|
||||
)
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 搜索筛选
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
# 平台筛选
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
base_query = base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern)
|
||||
)
|
||||
|
||||
# 排序
|
||||
base_query = base_query.order_by(Preference.scope_id)
|
||||
|
||||
# 分页结果
|
||||
result_query = base_query.offset(offset).limit(page_size)
|
||||
result = await session.execute(result_query)
|
||||
rows = result.fetchall()
|
||||
|
||||
# 查询总数(应用相同的筛选条件)
|
||||
count_base_query = (
|
||||
select(func.count(col(Preference.scope_id)))
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 应用相同的搜索和平台筛选条件到计数查询
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
count_base_query = count_base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
count_base_query = count_base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern)
|
||||
)
|
||||
|
||||
total_result = await session.execute(count_base_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
sessions_data = [
|
||||
{
|
||||
"session_id": row.session_id,
|
||||
"conversation_id": row.conversation_id,
|
||||
"persona_id": row.persona_id,
|
||||
"title": row.title,
|
||||
"persona_name": row.persona_name,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return sessions_data, total
|
||||
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id,
|
||||
@@ -267,9 +406,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
cutoff_time = now - timedelta(seconds=offset_sec)
|
||||
await session.execute(
|
||||
delete(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
PlatformMessageHistory.created_at < cutoff_time,
|
||||
col(PlatformMessageHistory.platform_id) == platform_id,
|
||||
col(PlatformMessageHistory.user_id) == user_id,
|
||||
col(PlatformMessageHistory.created_at) < cutoff_time,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -286,7 +425,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
)
|
||||
.order_by(PlatformMessageHistory.created_at.desc())
|
||||
.order_by(desc(PlatformMessageHistory.created_at))
|
||||
)
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
@@ -308,7 +447,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Get an attachment by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(Attachment.id == attachment_id)
|
||||
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -351,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(Persona).where(Persona.persona_id == persona_id)
|
||||
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
values = {}
|
||||
if system_prompt is not None:
|
||||
values["system_prompt"] = system_prompt
|
||||
@@ -371,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Persona).where(Persona.persona_id == persona_id)
|
||||
delete(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
)
|
||||
|
||||
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||
@@ -426,9 +565,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
col(Preference.key) == key,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
@@ -440,7 +579,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope, Preference.scope_id == scope_id
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
@@ -467,7 +607,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
DeprecatedPlatformStat(
|
||||
name=data.platform_id,
|
||||
count=data.count,
|
||||
timestamp=data.timestamp.timestamp(),
|
||||
timestamp=int(data.timestamp.timestamp()),
|
||||
)
|
||||
)
|
||||
return deprecated_stats
|
||||
@@ -525,7 +665,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
DeprecatedPlatformStat(
|
||||
name=platform_id,
|
||||
count=count,
|
||||
timestamp=start_time.timestamp(),
|
||||
timestamp=int(start_time.timestamp()),
|
||||
)
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
@@ -16,14 +16,42 @@ class BaseVecDB:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
async def insert(
|
||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
||||
) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
async def insert_batch(
|
||||
self,
|
||||
contents: list[str],
|
||||
metadatas: list[dict] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> int:
|
||||
"""
|
||||
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
fetch_k: int = 20,
|
||||
rerank: bool = False,
|
||||
metadata_filters: dict | None = None,
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
@@ -44,3 +72,6 @@ class BaseVecDB:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self): ...
|
||||
|
||||
@@ -1,59 +1,224 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import Text, Column
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class BaseDocModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class Document(BaseDocModel, table=True):
|
||||
"""SQLModel for documents table."""
|
||||
|
||||
__tablename__ = "documents" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
doc_id: str = Field(nullable=False)
|
||||
text: str = Field(nullable=False)
|
||||
metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
|
||||
created_at: datetime | None = Field(default=None)
|
||||
updated_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.engine: AsyncEngine | None = None
|
||||
self.async_session_maker: sessionmaker | None = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
await self.connect()
|
||||
async with self.engine.begin() as conn: # type: ignore
|
||||
# Create tables using SQLModel
|
||||
await conn.run_sync(BaseDocModel.metadata.create_all)
|
||||
|
||||
try:
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
|
||||
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE documents ADD COLUMN user_id TEXT "
|
||||
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
|
||||
)
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
|
||||
)
|
||||
)
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
await conn.commit()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
if self.engine is None:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.async_session_maker = sessionmaker(
|
||||
self.engine, # type: ignore
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
) # type: ignore
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
@asynccontextmanager
|
||||
async def get_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
async with self.async_session_maker() as session: # type: ignore
|
||||
yield session
|
||||
|
||||
async def get_documents(
|
||||
self,
|
||||
metadata_filters: dict,
|
||||
ids: list | None = None,
|
||||
offset: int | None = 0,
|
||||
limit: int | None = 100,
|
||||
) -> list[dict]:
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
ids (list | None): Optional list of document IDs to filter.
|
||||
offset (int | None): Offset for pagination.
|
||||
limit (int | None): Limit for pagination.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
list: The list of documents that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
if self.engine is None:
|
||||
logger.warning(
|
||||
"Database connection is not initialized, returning empty result"
|
||||
)
|
||||
return []
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
async with self.get_session() as session:
|
||||
query = select(Document)
|
||||
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
if ids is not None and len(ids) > 0:
|
||||
valid_ids = [int(i) for i in ids if i != -1]
|
||||
if valid_ids:
|
||||
query = query.where(col(Document.id).in_(valid_ids))
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
if offset is not None:
|
||||
query = query.offset(offset)
|
||||
|
||||
result = await session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
return [self._document_to_dict(doc) for doc in documents]
|
||||
|
||||
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
|
||||
"""Insert a single document and return its integer ID.
|
||||
|
||||
Args:
|
||||
doc_id (str): The document ID (UUID string).
|
||||
text (str): The document text.
|
||||
metadata (dict): The document metadata.
|
||||
|
||||
Returns:
|
||||
int: The integer ID of the inserted document.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
text=text,
|
||||
metadata_=json.dumps(metadata),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
session.add(document)
|
||||
await session.flush() # Flush to get the ID
|
||||
return document.id # type: ignore
|
||||
|
||||
async def insert_documents_batch(
|
||||
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
|
||||
) -> list[int]:
|
||||
"""Batch insert documents and return their integer IDs.
|
||||
|
||||
Args:
|
||||
doc_ids (list[str]): List of document IDs (UUID strings).
|
||||
texts (list[str]): List of document texts.
|
||||
metadatas (list[dict]): List of document metadata.
|
||||
|
||||
Returns:
|
||||
list[int]: List of integer IDs of the inserted documents.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
import json
|
||||
|
||||
documents = []
|
||||
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
text=text,
|
||||
metadata_=json.dumps(metadata),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
documents.append(document)
|
||||
session.add(document)
|
||||
|
||||
await session.flush() # Flush to get all IDs
|
||||
return [doc.id for doc in documents] # type: ignore
|
||||
|
||||
async def delete_document_by_doc_id(self, doc_id: str):
|
||||
"""Delete a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to delete.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
await session.delete(document)
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
@@ -62,28 +227,91 @@ class DocumentStorage:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
dict: The document data or None if not found.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
return self._document_to_dict(document)
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
"""Update a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
document.text = new_text
|
||||
document.updated_at = datetime.now()
|
||||
session.add(document)
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
"""Delete documents by their metadata filters.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
"""
|
||||
if self.engine is None:
|
||||
logger.warning(
|
||||
"Database connection is not initialized, skipping delete operation"
|
||||
)
|
||||
await self.connection.commit()
|
||||
return
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document)
|
||||
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
result = await session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
for doc in documents:
|
||||
await session.delete(doc)
|
||||
|
||||
async def count_documents(self, metadata_filters: dict | None = None) -> int:
|
||||
"""Count documents in the database.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict | None): Metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
int: The count of documents.
|
||||
"""
|
||||
if self.engine is None:
|
||||
logger.warning("Database connection is not initialized, returning 0")
|
||||
return 0
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(func.count(col(Document.id)))
|
||||
|
||||
if metadata_filters:
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
result = await session.execute(query)
|
||||
count = result.scalar_one_or_none()
|
||||
return count if count is not None else 0
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
@@ -91,11 +319,38 @@ class DocumentStorage:
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = text(
|
||||
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
|
||||
)
|
||||
result = await session.execute(query)
|
||||
rows = result.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def _document_to_dict(self, document: Document) -> dict:
|
||||
"""Convert a Document model to a dictionary.
|
||||
|
||||
Args:
|
||||
document (Document): The document to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": document.id,
|
||||
"doc_id": document.doc_id,
|
||||
"text": document.text,
|
||||
"metadata": document.metadata_,
|
||||
"created_at": document.created_at.isoformat()
|
||||
if isinstance(document.created_at, datetime)
|
||||
else document.created_at,
|
||||
"updated_at": document.updated_at.isoformat()
|
||||
if isinstance(document.updated_at, datetime)
|
||||
else document.updated_at,
|
||||
}
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
@@ -104,6 +359,8 @@ class DocumentStorage:
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
|
||||
Note: This method is kept for backward compatibility but is no longer used internally.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
@@ -116,6 +373,7 @@ class DocumentStorage:
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
if self.engine:
|
||||
await self.engine.dispose()
|
||||
self.engine = None
|
||||
self.async_session_maker = None
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
def __init__(self, dimension: int, path: str | None = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
@@ -18,7 +18,6 @@ class EmbeddingStorage:
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
@@ -29,12 +28,29 @@ class EmbeddingStorage:
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
|
||||
"""批量插入向量
|
||||
|
||||
Args:
|
||||
vectors (np.ndarray): 要插入的向量数组
|
||||
ids (list[int]): 向量的ID列表
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
if vectors.shape[1] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
|
||||
)
|
||||
self.index.add_with_ids(vectors, np.array(ids))
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
@@ -46,10 +62,22 @@ class EmbeddingStorage:
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def delete(self, ids: list[int]):
|
||||
"""删除向量
|
||||
|
||||
Args:
|
||||
ids (list[int]): 要删除的向量ID列表
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
id_array = np.array(ids, dtype=np.int64)
|
||||
self.index.remove_ids(id_array)
|
||||
await self.save_index()
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
@@ -44,18 +45,56 @@ class FaissVecDB(BaseVecDB):
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
# 使用 DocumentStorage 的方法插入文档
|
||||
int_id = await self.document_storage.insert_document(str_id, content, metadata)
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def insert_batch(
|
||||
self,
|
||||
contents: list[str],
|
||||
metadatas: list[dict] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
metadatas = metadatas or [{} for _ in contents]
|
||||
ids = ids or [str(uuid.uuid4()) for _ in contents]
|
||||
|
||||
start = time.time()
|
||||
logger.debug(f"Generating embeddings for {len(contents)} contents...")
|
||||
vectors = await self.embedding_provider.get_embeddings_batch(
|
||||
contents,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
end = time.time()
|
||||
logger.debug(
|
||||
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
|
||||
)
|
||||
|
||||
# 使用 DocumentStorage 的批量插入方法
|
||||
int_ids = await self.document_storage.insert_documents_batch(
|
||||
ids, contents, metadatas
|
||||
)
|
||||
|
||||
# 批量插入向量到 FAISS
|
||||
vectors_array = np.array(vectors).astype("float32")
|
||||
await self.embedding_storage.insert_batch(vectors_array, int_ids)
|
||||
return int_ids
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
@@ -119,23 +158,42 @@ class FaissVecDB(BaseVecDB):
|
||||
|
||||
return top_k_results
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
async def delete(self, doc_id: str):
|
||||
"""
|
||||
删除一条文档
|
||||
删除一条文档块(chunk)
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
# 获得对应的 int id
|
||||
result = await self.document_storage.get_document_by_doc_id(doc_id)
|
||||
int_id = result["id"] if result else None
|
||||
if int_id is None:
|
||||
return
|
||||
|
||||
# 使用 DocumentStorage 的删除方法
|
||||
await self.document_storage.delete_document_by_doc_id(doc_id)
|
||||
await self.embedding_storage.delete([int_id])
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
async def count_documents(self, metadata_filter: dict | None = None) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
|
||||
Args:
|
||||
metadata_filter (dict | None): 元数据过滤器
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
count = await self.document_storage.count_documents(
|
||||
metadata_filters=metadata_filter or {}
|
||||
)
|
||||
return count
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
"""
|
||||
根据元数据过滤器删除文档
|
||||
"""
|
||||
docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters, offset=None, limit=None
|
||||
)
|
||||
doc_ids: list[int] = [doc["id"] for doc in docs]
|
||||
await self.embedding_storage.delete(doc_ids)
|
||||
await self.document_storage.delete_documents(metadata_filters=metadata_filters)
|
||||
|
||||
@@ -23,7 +23,12 @@ class FileTokenService:
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
||||
async def check_token_expired(self, file_token: str) -> bool:
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
return file_token not in self.staged_files
|
||||
|
||||
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
|
||||
"""向令牌服务注册一个文件。
|
||||
|
||||
Args:
|
||||
|
||||
@@ -41,10 +41,13 @@ class InitialLoader:
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
|
||||
)
|
||||
task = asyncio.gather(
|
||||
core_task, self.dashboard_server.run()
|
||||
) # 启动核心任务和仪表板服务器
|
||||
|
||||
coro = self.dashboard_server.run()
|
||||
if coro:
|
||||
# 启动核心任务和仪表板服务器
|
||||
task = asyncio.gather(core_task, coro)
|
||||
else:
|
||||
task = core_task
|
||||
try:
|
||||
await task # 整个AstrBot在这里运行
|
||||
except asyncio.CancelledError:
|
||||
|
||||
11
astrbot/core/knowledge_base/chunking/__init__.py
Normal file
11
astrbot/core/knowledge_base/chunking/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
文档分块模块
|
||||
"""
|
||||
|
||||
from .base import BaseChunker
|
||||
from .fixed_size import FixedSizeChunker
|
||||
|
||||
__all__ = [
|
||||
"BaseChunker",
|
||||
"FixedSizeChunker",
|
||||
]
|
||||
24
astrbot/core/knowledge_base/chunking/base.py
Normal file
24
astrbot/core/knowledge_base/chunking/base.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""文档分块器基类
|
||||
|
||||
定义了文档分块处理的抽象接口。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseChunker(ABC):
|
||||
"""分块器基类
|
||||
|
||||
所有分块器都应该继承此类并实现 chunk 方法。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||
"""将文本分块
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
list[str]: 分块后的文本列表
|
||||
"""
|
||||
57
astrbot/core/knowledge_base/chunking/fixed_size.py
Normal file
57
astrbot/core/knowledge_base/chunking/fixed_size.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""固定大小分块器
|
||||
|
||||
按照固定的字符数将文本分块,支持重叠区域。
|
||||
"""
|
||||
|
||||
from .base import BaseChunker
|
||||
|
||||
|
||||
class FixedSizeChunker(BaseChunker):
|
||||
"""固定大小分块器
|
||||
|
||||
按照固定的字符数分块,并支持块之间的重叠。
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
|
||||
"""初始化分块器
|
||||
|
||||
Args:
|
||||
chunk_size: 块的大小(字符数)
|
||||
chunk_overlap: 块之间的重叠字符数
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||
"""固定大小分块
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
chunk_size: 每个文本块的最大大小
|
||||
chunk_overlap: 每个文本块之间的重叠部分大小
|
||||
|
||||
Returns:
|
||||
list[str]: 分块后的文本列表
|
||||
"""
|
||||
chunk_size = kwargs.get("chunk_size", self.chunk_size)
|
||||
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
text_len = len(text)
|
||||
|
||||
while start < text_len:
|
||||
end = start + chunk_size
|
||||
chunk = text[start:end]
|
||||
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
# 移动窗口,保留重叠部分
|
||||
start = end - chunk_overlap
|
||||
|
||||
# 防止无限循环: 如果重叠过大,直接移到end
|
||||
if start >= end or chunk_overlap >= chunk_size:
|
||||
start = end
|
||||
|
||||
return chunks
|
||||
155
astrbot/core/knowledge_base/chunking/recursive.py
Normal file
155
astrbot/core/knowledge_base/chunking/recursive.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from collections.abc import Callable
|
||||
from .base import BaseChunker
|
||||
|
||||
|
||||
class RecursiveCharacterChunker(BaseChunker):
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 100,
|
||||
length_function: Callable[[str], int] = len,
|
||||
is_separator_regex: bool = False,
|
||||
separators: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
初始化递归字符文本分割器
|
||||
|
||||
Args:
|
||||
chunk_size: 每个文本块的最大大小
|
||||
chunk_overlap: 每个文本块之间的重叠部分大小
|
||||
length_function: 计算文本长度的函数
|
||||
is_separator_regex: 分隔符是否为正则表达式
|
||||
separators: 用于分割文本的分隔符列表,按优先级排序
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.length_function = length_function
|
||||
self.is_separator_regex = is_separator_regex
|
||||
|
||||
# 默认分隔符列表,按优先级从高到低
|
||||
self.separators = separators or [
|
||||
"\n\n", # 段落
|
||||
"\n", # 换行
|
||||
"。", # 中文句子
|
||||
",", # 中文逗号
|
||||
". ", # 句子
|
||||
", ", # 逗号分隔
|
||||
" ", # 单词
|
||||
"", # 字符
|
||||
]
|
||||
|
||||
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||
"""
|
||||
递归地将文本分割成块
|
||||
|
||||
Args:
|
||||
text: 要分割的文本
|
||||
chunk_size: 每个文本块的最大大小
|
||||
chunk_overlap: 每个文本块之间的重叠部分大小
|
||||
|
||||
Returns:
|
||||
分割后的文本块列表
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
|
||||
chunk_size = kwargs.get("chunk_size", self.chunk_size)
|
||||
|
||||
text_length = self.length_function(text)
|
||||
if text_length <= chunk_size:
|
||||
return [text]
|
||||
|
||||
for separator in self.separators:
|
||||
if separator == "":
|
||||
return self._split_by_character(text, chunk_size, overlap)
|
||||
|
||||
if separator in text:
|
||||
splits = text.split(separator)
|
||||
# 重新添加分隔符(除了最后一个片段)
|
||||
splits = [s + separator for s in splits[:-1]] + [splits[-1]]
|
||||
splits = [s for s in splits if s]
|
||||
if len(splits) == 1:
|
||||
continue
|
||||
|
||||
# 递归合并分割后的文本块
|
||||
final_chunks = []
|
||||
current_chunk = []
|
||||
current_chunk_length = 0
|
||||
|
||||
for split in splits:
|
||||
split_length = self.length_function(split)
|
||||
|
||||
# 如果单个分割部分已经超过了chunk_size,需要递归分割
|
||||
if split_length > chunk_size:
|
||||
# 先处理当前积累的块
|
||||
if current_chunk:
|
||||
combined_text = "".join(current_chunk)
|
||||
final_chunks.extend(
|
||||
await self.chunk(
|
||||
combined_text,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=overlap,
|
||||
)
|
||||
)
|
||||
current_chunk = []
|
||||
current_chunk_length = 0
|
||||
|
||||
# 递归分割过大的部分
|
||||
final_chunks.extend(
|
||||
await self.chunk(
|
||||
split, chunk_size=chunk_size, chunk_overlap=overlap
|
||||
)
|
||||
)
|
||||
# 如果添加这部分会使当前块超过chunk_size
|
||||
elif current_chunk_length + split_length > chunk_size:
|
||||
# 合并当前块并添加到结果中
|
||||
combined_text = "".join(current_chunk)
|
||||
final_chunks.append(combined_text)
|
||||
|
||||
# 处理重叠部分
|
||||
overlap_start = max(0, len(combined_text) - overlap)
|
||||
if overlap_start > 0:
|
||||
overlap_text = combined_text[overlap_start:]
|
||||
current_chunk = [overlap_text, split]
|
||||
current_chunk_length = (
|
||||
self.length_function(overlap_text) + split_length
|
||||
)
|
||||
else:
|
||||
current_chunk = [split]
|
||||
current_chunk_length = split_length
|
||||
else:
|
||||
# 添加到当前块
|
||||
current_chunk.append(split)
|
||||
current_chunk_length += split_length
|
||||
|
||||
# 处理剩余的块
|
||||
if current_chunk:
|
||||
final_chunks.append("".join(current_chunk))
|
||||
|
||||
return final_chunks
|
||||
|
||||
return [text]
|
||||
|
||||
def _split_by_character(
|
||||
self, text: str, chunk_size: int | None = None, overlap: int | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
按字符级别分割文本
|
||||
|
||||
Args:
|
||||
text: 要分割的文本
|
||||
|
||||
Returns:
|
||||
分割后的文本块列表
|
||||
"""
|
||||
chunk_size = chunk_size or self.chunk_size
|
||||
overlap = overlap or self.chunk_overlap
|
||||
result = []
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
end = min(i + chunk_size, len(text))
|
||||
result.append(text[i:end])
|
||||
if end == len(text):
|
||||
break
|
||||
|
||||
return result
|
||||
299
astrbot/core/knowledge_base/kb_db_sqlite.py
Normal file
299
astrbot/core/knowledge_base/kb_db_sqlite.py
Normal file
@@ -0,0 +1,299 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import col, desc
|
||||
from sqlalchemy import text, func, select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
BaseKBModel,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
"""初始化知识库数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# 确保目录存在
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建异步引擎
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""获取数据库会话
|
||||
|
||||
用法:
|
||||
async with kb_db.get_db() as session:
|
||||
# 执行数据库操作
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(BaseKBModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
self.inited = True
|
||||
|
||||
async def migrate_to_v1(self) -> None:
|
||||
"""执行知识库数据库 v1 迁移
|
||||
|
||||
创建所有必要的索引以优化查询性能
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
# 创建知识库表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||
"ON knowledge_bases(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||
"ON knowledge_bases(kb_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||
"ON knowledge_bases(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建文档表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||
"ON kb_documents(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||
"ON kb_documents(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||
"ON kb_documents(doc_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||
"ON kb_documents(file_type)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||
"ON kb_documents(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建多媒体表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||
"ON kb_media(media_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||
"ON kb_media(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||
"ON kb_media(media_type)"
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(desc(KnowledgeBase.created_at))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_kbs(self) -> int:
|
||||
"""统计知识库数量"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(func.count(col(KnowledgeBase.id)))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_documents_by_kb(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(col(KBDocument.kb_id) == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(desc(KBDocument.created_at))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||
"""统计知识库的文档数量"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(func.count(col(KBDocument.id))).where(
|
||||
col(KBDocument.kb_id) == kb_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument, KnowledgeBase)
|
||||
.join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
|
||||
.where(col(KBDocument.doc_id) == doc_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return {
|
||||
"document": row[0],
|
||||
"knowledge_base": row[1],
|
||||
}
|
||||
|
||||
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
|
||||
"""删除单个文档及其相关数据"""
|
||||
# 在知识库表中删除
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
# 删除文档记录
|
||||
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||
await session.execute(delete_stmt)
|
||||
await session.commit()
|
||||
|
||||
# 在 vec db 中删除相关向量
|
||||
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
|
||||
"""更新知识库统计信息"""
|
||||
chunk_cnt = await vec_db.count_documents()
|
||||
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
update_stmt = (
|
||||
update(KnowledgeBase)
|
||||
.where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
.values(
|
||||
doc_count=select(func.count(col(KBDocument.id)))
|
||||
.where(col(KBDocument.kb_id) == kb_id)
|
||||
.scalar_subquery(),
|
||||
chunk_count=chunk_cnt,
|
||||
)
|
||||
)
|
||||
|
||||
await session.execute(update_stmt)
|
||||
await session.commit()
|
||||
348
astrbot/core/knowledge_base/kb_helper.py
Normal file
348
astrbot/core/knowledge_base/kb_helper.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import uuid
|
||||
import aiofiles
|
||||
import json
|
||||
from pathlib import Path
|
||||
from .models import KnowledgeBase, KBDocument, KBMedia
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from .parsers.util import select_parser
|
||||
from .chunking.base import BaseChunker
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class KBHelper:
|
||||
vec_db: BaseVecDB
|
||||
kb: KnowledgeBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
kb: KnowledgeBase,
|
||||
provider_manager: ProviderManager,
|
||||
kb_root_dir: str,
|
||||
chunker: BaseChunker,
|
||||
):
|
||||
self.kb_db = kb_db
|
||||
self.kb = kb
|
||||
self.prov_mgr = provider_manager
|
||||
self.kb_root_dir = kb_root_dir
|
||||
self.chunker = chunker
|
||||
|
||||
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
|
||||
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
|
||||
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
|
||||
|
||||
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def initialize(self):
|
||||
await self._ensure_vec_db()
|
||||
|
||||
async def get_ep(self) -> EmbeddingProvider:
|
||||
if not self.kb.embedding_provider_id:
|
||||
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
|
||||
self.kb.embedding_provider_id
|
||||
) # type: ignore
|
||||
if not ep:
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
|
||||
)
|
||||
return ep
|
||||
|
||||
async def get_rp(self) -> RerankProvider | None:
|
||||
if not self.kb.rerank_provider_id:
|
||||
return None
|
||||
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
|
||||
self.kb.rerank_provider_id
|
||||
) # type: ignore
|
||||
if not rp:
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
|
||||
)
|
||||
return rp
|
||||
|
||||
async def _ensure_vec_db(self) -> FaissVecDB:
|
||||
if not self.kb.embedding_provider_id:
|
||||
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||
|
||||
ep = await self.get_ep()
|
||||
rp = await self.get_rp()
|
||||
|
||||
vec_db = FaissVecDB(
|
||||
doc_store_path=str(self.kb_dir / "doc.db"),
|
||||
index_store_path=str(self.kb_dir / "index.faiss"),
|
||||
embedding_provider=ep,
|
||||
rerank_provider=rp,
|
||||
)
|
||||
await vec_db.initialize()
|
||||
self.vec_db = vec_db
|
||||
return vec_db
|
||||
|
||||
async def delete_vec_db(self):
|
||||
"""删除知识库的向量数据库和所有相关文件"""
|
||||
import shutil
|
||||
|
||||
await self.terminate()
|
||||
if self.kb_dir.exists():
|
||||
shutil.rmtree(self.kb_dir)
|
||||
|
||||
async def terminate(self):
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
流程:
|
||||
1. 保存原始文件
|
||||
2. 解析文档内容
|
||||
3. 提取多媒体资源
|
||||
4. 分块处理
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total)
|
||||
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
|
||||
- current: 当前进度
|
||||
- total: 总数
|
||||
"""
|
||||
await self._ensure_vec_db()
|
||||
doc_id = str(uuid.uuid4())
|
||||
media_paths: list[Path] = []
|
||||
|
||||
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
|
||||
# async with aiofiles.open(file_path, "wb") as f:
|
||||
# await f.write(file_content)
|
||||
|
||||
try:
|
||||
# 阶段1: 解析文档
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 0, 100)
|
||||
|
||||
parser = await select_parser(f".{file_type}")
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 100, 100)
|
||||
|
||||
# 保存媒体文件
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await self._save_media(
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 阶段2: 分块
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 0, 100)
|
||||
|
||||
chunks_text = await self.chunker.chunk(
|
||||
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
contents = []
|
||||
metadatas = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
contents.append(chunk_text)
|
||||
metadatas.append(
|
||||
{
|
||||
"kb_id": self.kb.kb_id,
|
||||
"kb_doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
}
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 100, 100)
|
||||
|
||||
# 阶段3: 生成向量(带进度回调)
|
||||
async def embedding_progress_callback(current, total):
|
||||
if progress_callback:
|
||||
await progress_callback("embedding", current, total)
|
||||
|
||||
await self.vec_db.insert_batch(
|
||||
contents=contents,
|
||||
metadatas=metadatas,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=embedding_progress_callback,
|
||||
)
|
||||
|
||||
# 保存文档的元数据
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
# file_path=str(file_path),
|
||||
file_path="",
|
||||
chunk_count=len(chunks_text),
|
||||
media_count=0,
|
||||
)
|
||||
async with self.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
for media in saved_media:
|
||||
session.add(media)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
|
||||
await self.refresh_kb()
|
||||
await self.refresh_document(doc_id)
|
||||
return doc
|
||||
except Exception as e:
|
||||
logger.error(f"上传文档失败: {e}")
|
||||
# if file_path.exists():
|
||||
# file_path.unlink()
|
||||
|
||||
for media_path in media_paths:
|
||||
try:
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as me:
|
||||
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||
|
||||
raise e
|
||||
|
||||
async def list_documents(
|
||||
self, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
|
||||
return docs
|
||||
|
||||
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||
"""获取单个文档"""
|
||||
doc = await self.kb_db.get_document_by_id(doc_id)
|
||||
return doc
|
||||
|
||||
async def delete_document(self, doc_id: str):
|
||||
"""删除单个文档及其相关数据"""
|
||||
await self.kb_db.delete_document_by_id(
|
||||
doc_id=doc_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.kb_db.update_kb_stats(
|
||||
kb_id=self.kb.kb_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.refresh_kb()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str, doc_id: str):
|
||||
"""删除单个文本块及其相关数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
await vec_db.delete(chunk_id)
|
||||
await self.kb_db.update_kb_stats(
|
||||
kb_id=self.kb.kb_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.refresh_kb()
|
||||
await self.refresh_document(doc_id)
|
||||
|
||||
async def refresh_kb(self):
|
||||
if self.kb:
|
||||
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
|
||||
if kb:
|
||||
self.kb = kb
|
||||
|
||||
async def refresh_document(self, doc_id: str) -> None:
|
||||
"""更新文档的元数据"""
|
||||
doc = await self.get_document(doc_id)
|
||||
if not doc:
|
||||
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
|
||||
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
|
||||
doc.chunk_count = chunk_count
|
||||
async with self.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
await session.commit()
|
||||
await session.refresh(doc)
|
||||
|
||||
async def get_chunks_by_doc_id(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[dict]:
|
||||
"""获取文档的所有块及其元数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
chunks = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
|
||||
)
|
||||
result = []
|
||||
for chunk in chunks:
|
||||
chunk_md = json.loads(chunk["metadata"])
|
||||
result.append(
|
||||
{
|
||||
"chunk_id": chunk["doc_id"],
|
||||
"doc_id": chunk_md["kb_doc_id"],
|
||||
"kb_id": chunk_md["kb_id"],
|
||||
"chunk_index": chunk_md["chunk_index"],
|
||||
"content": chunk["text"],
|
||||
"char_count": len(chunk["text"]),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
|
||||
"""获取文档的块数量"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
|
||||
return count
|
||||
|
||||
async def _save_media(
|
||||
self,
|
||||
doc_id: str,
|
||||
media_type: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
mime_type: str,
|
||||
) -> KBMedia:
|
||||
"""保存多媒体资源"""
|
||||
media_id = str(uuid.uuid4())
|
||||
ext = Path(file_name).suffix
|
||||
|
||||
# 保存文件
|
||||
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
media = KBMedia(
|
||||
media_id=media_id,
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
media_type=media_type,
|
||||
file_name=file_name,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
return media
|
||||
287
astrbot/core/knowledge_base/kb_mgr.py
Normal file
287
astrbot/core/knowledge_base/kb_mgr.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
|
||||
from .retrieval.manager import RetrievalManager, RetrievalResult
|
||||
from .retrieval.sparse_retriever import SparseRetriever
|
||||
from .retrieval.rank_fusion import RankFusion
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
|
||||
# from .chunking.fixed_size import FixedSizeChunker
|
||||
from .chunking.recursive import RecursiveCharacterChunker
|
||||
from .kb_helper import KBHelper
|
||||
|
||||
from .models import KnowledgeBase
|
||||
|
||||
|
||||
FILES_PATH = "data/knowledge_base"
|
||||
DB_PATH = Path(FILES_PATH) / "kb.db"
|
||||
"""Knowledge Base storage root directory"""
|
||||
CHUNKER = RecursiveCharacterChunker()
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
kb_db: KBSQLiteDatabase
|
||||
retrieval_manager: RetrievalManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||||
self.provider_manager = provider_manager
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
self.kb_insts: dict[str, KBHelper] = {}
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化知识库模块"""
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 初始化数据库
|
||||
await self._init_kb_database()
|
||||
|
||||
# 初始化检索管理器
|
||||
sparse_retriever = SparseRetriever(self.kb_db)
|
||||
rank_fusion = RankFusion(self.kb_db)
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_db,
|
||||
)
|
||||
await self.load_kbs()
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"知识库模块导入失败: {e}")
|
||||
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||
except Exception as e:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
|
||||
|
||||
async def load_kbs(self):
|
||||
"""加载所有知识库实例"""
|
||||
kb_records = await self.kb_db.list_kbs()
|
||||
for record in kb_records:
|
||||
kb_helper = KBHelper(
|
||||
kb_db=self.kb_db,
|
||||
kb=record,
|
||||
provider_manager=self.provider_manager,
|
||||
kb_root_dir=FILES_PATH,
|
||||
chunker=CHUNKER,
|
||||
)
|
||||
await kb_helper.initialize()
|
||||
self.kb_insts[record.kb_id] = kb_helper
|
||||
|
||||
async def create_kb(
|
||||
self,
|
||||
kb_name: str,
|
||||
description: str | None = None,
|
||||
emoji: str | None = None,
|
||||
embedding_provider_id: str | None = None,
|
||||
rerank_provider_id: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
top_k_dense: int | None = None,
|
||||
top_k_sparse: int | None = None,
|
||||
top_m_final: int | None = None,
|
||||
) -> KBHelper:
|
||||
"""创建新的知识库实例"""
|
||||
kb = KnowledgeBase(
|
||||
kb_name=kb_name,
|
||||
description=description,
|
||||
emoji=emoji or "📚",
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||
)
|
||||
async with self.kb_db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
|
||||
kb_helper = KBHelper(
|
||||
kb_db=self.kb_db,
|
||||
kb=kb,
|
||||
provider_manager=self.provider_manager,
|
||||
kb_root_dir=FILES_PATH,
|
||||
chunker=CHUNKER,
|
||||
)
|
||||
await kb_helper.initialize()
|
||||
self.kb_insts[kb.kb_id] = kb_helper
|
||||
return kb_helper
|
||||
|
||||
async def get_kb(self, kb_id: str) -> KBHelper | None:
|
||||
"""获取知识库实例"""
|
||||
if kb_id in self.kb_insts:
|
||||
return self.kb_insts[kb_id]
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
|
||||
"""通过名称获取知识库实例"""
|
||||
for kb_helper in self.kb_insts.values():
|
||||
if kb_helper.kb.kb_name == kb_name:
|
||||
return kb_helper
|
||||
return None
|
||||
|
||||
async def delete_kb(self, kb_id: str) -> bool:
|
||||
"""删除知识库实例"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return False
|
||||
|
||||
await kb_helper.delete_vec_db()
|
||||
async with self.kb_db.get_db() as session:
|
||||
await session.delete(kb_helper.kb)
|
||||
await session.commit()
|
||||
|
||||
self.kb_insts.pop(kb_id, None)
|
||||
return True
|
||||
|
||||
async def list_kbs(self) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库实例"""
|
||||
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
|
||||
return kbs
|
||||
|
||||
async def update_kb(
|
||||
self,
|
||||
kb_id: str,
|
||||
kb_name: str,
|
||||
description: str | None = None,
|
||||
emoji: str | None = None,
|
||||
embedding_provider_id: str | None = None,
|
||||
rerank_provider_id: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
top_k_dense: int | None = None,
|
||||
top_k_sparse: int | None = None,
|
||||
top_m_final: int | None = None,
|
||||
) -> KBHelper | None:
|
||||
"""更新知识库实例"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return None
|
||||
|
||||
kb = kb_helper.kb
|
||||
if kb_name is not None:
|
||||
kb.kb_name = kb_name
|
||||
if description is not None:
|
||||
kb.description = description
|
||||
if emoji is not None:
|
||||
kb.emoji = emoji
|
||||
if embedding_provider_id is not None:
|
||||
kb.embedding_provider_id = embedding_provider_id
|
||||
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
|
||||
if chunk_size is not None:
|
||||
kb.chunk_size = chunk_size
|
||||
if chunk_overlap is not None:
|
||||
kb.chunk_overlap = chunk_overlap
|
||||
if top_k_dense is not None:
|
||||
kb.top_k_dense = top_k_dense
|
||||
if top_k_sparse is not None:
|
||||
kb.top_k_sparse = top_k_sparse
|
||||
if top_m_final is not None:
|
||||
kb.top_m_final = top_m_final
|
||||
async with self.kb_db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
|
||||
return kb_helper
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_names: list[str],
|
||||
top_k_fusion: int = 20,
|
||||
top_m_final: int = 5,
|
||||
) -> dict | None:
|
||||
"""从指定知识库中检索相关内容"""
|
||||
kb_ids = []
|
||||
kb_id_helper_map = {}
|
||||
for kb_name in kb_names:
|
||||
if kb_helper := await self.get_kb_by_name(kb_name):
|
||||
kb_ids.append(kb_helper.kb.kb_id)
|
||||
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
|
||||
|
||||
if not kb_ids:
|
||||
return {}
|
||||
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
kb_id_helper_map=kb_id_helper_map,
|
||||
top_k_fusion=top_k_fusion,
|
||||
top_m_final=top_m_final,
|
||||
)
|
||||
if not results:
|
||||
return None
|
||||
|
||||
context_text = self._format_context(results)
|
||||
|
||||
results_dict = [
|
||||
{
|
||||
"chunk_id": r.chunk_id,
|
||||
"doc_id": r.doc_id,
|
||||
"kb_id": r.kb_id,
|
||||
"kb_name": r.kb_name,
|
||||
"doc_name": r.doc_name,
|
||||
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
"char_count": r.metadata.get("char_count", 0),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"context_text": context_text,
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
def _format_context(self, results: list[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
str: 格式化的上下文文本
|
||||
"""
|
||||
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
lines.append(f"【知识 {i}】")
|
||||
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||
lines.append(f"内容: {result.content}")
|
||||
lines.append(f"相关度: {result.score:.2f}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def terminate(self):
|
||||
"""终止所有知识库实例,关闭数据库连接"""
|
||||
for kb_id, kb_helper in self.kb_insts.items():
|
||||
try:
|
||||
await kb_helper.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_insts.clear()
|
||||
|
||||
# 关闭元数据数据库
|
||||
if hasattr(self, "kb_db") and self.kb_db:
|
||||
try:
|
||||
await self.kb_db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭知识库元数据数据库失败: {e}")
|
||||
114
astrbot/core/knowledge_base/models.py
Normal file
114
astrbot/core/knowledge_base/models.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
|
||||
|
||||
|
||||
class BaseKBModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class KnowledgeBase(BaseKBModel, table=True):
|
||||
"""知识库表
|
||||
|
||||
存储知识库的基本信息和统计数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_bases" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
kb_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
kb_name: str = Field(max_length=100, nullable=False)
|
||||
description: str | None = Field(default=None, sa_type=Text)
|
||||
emoji: str | None = Field(default="📚", max_length=10)
|
||||
embedding_provider_id: str | None = Field(default=None, max_length=100)
|
||||
rerank_provider_id: str | None = Field(default=None, max_length=100)
|
||||
# 分块配置参数
|
||||
chunk_size: int | None = Field(default=512, nullable=True)
|
||||
chunk_overlap: int | None = Field(default=50, nullable=True)
|
||||
# 检索配置参数
|
||||
top_k_dense: int | None = Field(default=50, nullable=True)
|
||||
top_k_sparse: int | None = Field(default=50, nullable=True)
|
||||
top_m_final: int | None = Field(default=5, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
doc_count: int = Field(default=0, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"kb_name",
|
||||
name="uix_kb_name",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KBDocument(BaseKBModel, table=True):
|
||||
"""文档表
|
||||
|
||||
存储上传到知识库的文档元数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_documents" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
doc_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
doc_name: str = Field(max_length=255, nullable=False)
|
||||
file_type: str = Field(max_length=20, nullable=False)
|
||||
file_size: int = Field(nullable=False)
|
||||
file_path: str = Field(max_length=512, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
media_count: int = Field(default=0, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class KBMedia(BaseKBModel, table=True):
|
||||
"""多媒体资源表
|
||||
|
||||
存储从文档中提取的图片、视频等多媒体资源。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_media" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
media_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
doc_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
media_type: str = Field(max_length=20, nullable=False)
|
||||
file_name: str = Field(max_length=255, nullable=False)
|
||||
file_path: str = Field(max_length=512, nullable=False)
|
||||
file_size: int = Field(nullable=False)
|
||||
mime_type: str = Field(max_length=100, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
15
astrbot/core/knowledge_base/parsers/__init__.py
Normal file
15
astrbot/core/knowledge_base/parsers/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
文档解析器模块
|
||||
"""
|
||||
|
||||
from .base import BaseParser, MediaItem, ParseResult
|
||||
from .text_parser import TextParser
|
||||
from .pdf_parser import PDFParser
|
||||
|
||||
__all__ = [
|
||||
"BaseParser",
|
||||
"MediaItem",
|
||||
"ParseResult",
|
||||
"TextParser",
|
||||
"PDFParser",
|
||||
]
|
||||
50
astrbot/core/knowledge_base/parsers/base.py
Normal file
50
astrbot/core/knowledge_base/parsers/base.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""文档解析器基类和数据结构
|
||||
|
||||
定义了文档解析器的抽象接口和相关数据类。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaItem:
|
||||
"""多媒体项
|
||||
|
||||
表示从文档中提取的多媒体资源。
|
||||
"""
|
||||
|
||||
media_type: str # image, video
|
||||
file_name: str
|
||||
content: bytes
|
||||
mime_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParseResult:
|
||||
"""解析结果
|
||||
|
||||
包含解析后的文本内容和提取的多媒体资源。
|
||||
"""
|
||||
|
||||
text: str
|
||||
media: list[MediaItem]
|
||||
|
||||
|
||||
class BaseParser(ABC):
|
||||
"""文档解析器基类
|
||||
|
||||
所有文档解析器都应该继承此类并实现 parse 方法。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
"""解析文档
|
||||
|
||||
Args:
|
||||
file_content: 文件内容
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
ParseResult: 解析结果
|
||||
"""
|
||||
25
astrbot/core/knowledge_base/parsers/markitdown_parser.py
Normal file
25
astrbot/core/knowledge_base/parsers/markitdown_parser.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import io
|
||||
import os
|
||||
|
||||
from astrbot.core.knowledge_base.parsers.base import (
|
||||
BaseParser,
|
||||
ParseResult,
|
||||
)
|
||||
from markitdown_no_magika import MarkItDown, StreamInfo
|
||||
|
||||
|
||||
class MarkitdownParser(BaseParser):
|
||||
"""解析 docx, xls, xlsx 格式"""
|
||||
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
bio = io.BytesIO(file_content)
|
||||
stream_info = StreamInfo(
|
||||
extension=os.path.splitext(file_name)[1].lower(),
|
||||
filename=file_name,
|
||||
)
|
||||
result = md.convert(bio, stream_info=stream_info)
|
||||
return ParseResult(
|
||||
text=result.markdown,
|
||||
media=[],
|
||||
)
|
||||
100
astrbot/core/knowledge_base/parsers/pdf_parser.py
Normal file
100
astrbot/core/knowledge_base/parsers/pdf_parser.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""PDF 文件解析器
|
||||
|
||||
支持解析 PDF 文件中的文本和图片资源。
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
from pypdf import PdfReader
|
||||
|
||||
from astrbot.core.knowledge_base.parsers.base import (
|
||||
BaseParser,
|
||||
MediaItem,
|
||||
ParseResult,
|
||||
)
|
||||
|
||||
|
||||
class PDFParser(BaseParser):
|
||||
"""PDF 文档解析器
|
||||
|
||||
提取 PDF 中的文本内容和嵌入的图片资源。
|
||||
"""
|
||||
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
"""解析 PDF 文件
|
||||
|
||||
Args:
|
||||
file_content: 文件内容
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
ParseResult: 包含文本和图片的解析结果
|
||||
"""
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
reader = PdfReader(pdf_file)
|
||||
|
||||
text_parts = []
|
||||
media_items = []
|
||||
|
||||
# 提取文本
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
# 提取图片
|
||||
image_counter = 0
|
||||
for page_num, page in enumerate(reader.pages):
|
||||
try:
|
||||
# 安全检查 Resources
|
||||
if "/Resources" not in page:
|
||||
continue
|
||||
|
||||
resources = page["/Resources"]
|
||||
if not resources or "/XObject" not in resources: # type: ignore
|
||||
continue
|
||||
|
||||
xobjects = resources["/XObject"].get_object() # type: ignore
|
||||
if not xobjects:
|
||||
continue
|
||||
|
||||
for obj_name in xobjects:
|
||||
try:
|
||||
obj = xobjects[obj_name]
|
||||
|
||||
if obj.get("/Subtype") != "/Image":
|
||||
continue
|
||||
|
||||
# 提取图片数据
|
||||
image_data = obj.get_data()
|
||||
|
||||
# 确定格式
|
||||
filter_type = obj.get("/Filter", "")
|
||||
if filter_type == "/DCTDecode":
|
||||
ext = "jpg"
|
||||
mime_type = "image/jpeg"
|
||||
elif filter_type == "/FlateDecode":
|
||||
ext = "png"
|
||||
mime_type = "image/png"
|
||||
else:
|
||||
ext = "png"
|
||||
mime_type = "image/png"
|
||||
|
||||
image_counter += 1
|
||||
media_items.append(
|
||||
MediaItem(
|
||||
media_type="image",
|
||||
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
|
||||
content=image_data,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 单个图片提取失败不影响整体
|
||||
continue
|
||||
except Exception:
|
||||
# 页面处理失败不影响其他页面
|
||||
continue
|
||||
|
||||
full_text = "\n\n".join(text_parts)
|
||||
return ParseResult(text=full_text, media=media_items)
|
||||
41
astrbot/core/knowledge_base/parsers/text_parser.py
Normal file
41
astrbot/core/knowledge_base/parsers/text_parser.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""文本文件解析器
|
||||
|
||||
支持解析 TXT 和 Markdown 文件。
|
||||
"""
|
||||
|
||||
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
|
||||
|
||||
|
||||
class TextParser(BaseParser):
|
||||
"""TXT/MD 文本解析器
|
||||
|
||||
支持多种字符编码的自动检测。
|
||||
"""
|
||||
|
||||
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||
"""解析文本文件
|
||||
|
||||
尝试使用多种编码解析文件内容。
|
||||
|
||||
Args:
|
||||
file_content: 文件内容
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
ParseResult: 解析结果,不包含多媒体资源
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无法解码文件
|
||||
"""
|
||||
# 尝试多种编码
|
||||
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:
|
||||
try:
|
||||
text = file_content.decode(encoding)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"无法解码文件: {file_name}")
|
||||
|
||||
# 文本文件无多媒体资源
|
||||
return ParseResult(text=text, media=[])
|
||||
13
astrbot/core/knowledge_base/parsers/util.py
Normal file
13
astrbot/core/knowledge_base/parsers/util.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .base import BaseParser
|
||||
|
||||
|
||||
async def select_parser(ext: str) -> BaseParser:
|
||||
if ext in {".md", ".txt", ".markdown", ".xlsx", ".docx", ".xls"}:
|
||||
from .markitdown_parser import MarkitdownParser
|
||||
|
||||
return MarkitdownParser()
|
||||
elif ext == ".pdf":
|
||||
from .pdf_parser import PDFParser
|
||||
|
||||
return PDFParser()
|
||||
raise ValueError(f"暂时不支持的文件格式: {ext}")
|
||||
16
astrbot/core/knowledge_base/retrieval/__init__.py
Normal file
16
astrbot/core/knowledge_base/retrieval/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
检索模块
|
||||
"""
|
||||
|
||||
from .manager import RetrievalManager, RetrievalResult
|
||||
from .sparse_retriever import SparseRetriever, SparseResult
|
||||
from .rank_fusion import RankFusion, FusedResult
|
||||
|
||||
__all__ = [
|
||||
"RetrievalManager",
|
||||
"RetrievalResult",
|
||||
"SparseRetriever",
|
||||
"SparseResult",
|
||||
"RankFusion",
|
||||
"FusedResult",
|
||||
]
|
||||
767
astrbot/core/knowledge_base/retrieval/hit_stopwords.txt
Normal file
767
astrbot/core/knowledge_base/retrieval/hit_stopwords.txt
Normal file
@@ -0,0 +1,767 @@
|
||||
———
|
||||
》),
|
||||
)÷(1-
|
||||
”,
|
||||
)、
|
||||
=(
|
||||
:
|
||||
→
|
||||
℃
|
||||
&
|
||||
*
|
||||
一一
|
||||
~~~~
|
||||
’
|
||||
.
|
||||
『
|
||||
.一
|
||||
./
|
||||
--
|
||||
』
|
||||
=″
|
||||
【
|
||||
[*]
|
||||
}>
|
||||
[⑤]]
|
||||
[①D]
|
||||
c]
|
||||
ng昉
|
||||
*
|
||||
//
|
||||
[
|
||||
]
|
||||
[②e]
|
||||
[②g]
|
||||
={
|
||||
}
|
||||
,也
|
||||
‘
|
||||
A
|
||||
[①⑥]
|
||||
[②B]
|
||||
[①a]
|
||||
[④a]
|
||||
[①③]
|
||||
[③h]
|
||||
③]
|
||||
1.
|
||||
--
|
||||
[②b]
|
||||
’‘
|
||||
×××
|
||||
[①⑧]
|
||||
0:2
|
||||
=[
|
||||
[⑤b]
|
||||
[②c]
|
||||
[④b]
|
||||
[②③]
|
||||
[③a]
|
||||
[④c]
|
||||
[①⑤]
|
||||
[①⑦]
|
||||
[①g]
|
||||
∈[
|
||||
[①⑨]
|
||||
[①④]
|
||||
[①c]
|
||||
[②f]
|
||||
[②⑧]
|
||||
[②①]
|
||||
[①C]
|
||||
[③c]
|
||||
[③g]
|
||||
[②⑤]
|
||||
[②②]
|
||||
一.
|
||||
[①h]
|
||||
.数
|
||||
[]
|
||||
[①B]
|
||||
数/
|
||||
[①i]
|
||||
[③e]
|
||||
[①①]
|
||||
[④d]
|
||||
[④e]
|
||||
[③b]
|
||||
[⑤a]
|
||||
[①A]
|
||||
[②⑧]
|
||||
[②⑦]
|
||||
[①d]
|
||||
[②j]
|
||||
〕〔
|
||||
][
|
||||
://
|
||||
′∈
|
||||
[②④
|
||||
[⑤e]
|
||||
12%
|
||||
b]
|
||||
...
|
||||
...................
|
||||
…………………………………………………③
|
||||
ZXFITL
|
||||
[③F]
|
||||
」
|
||||
[①o]
|
||||
]∧′=[
|
||||
∪φ∈
|
||||
′|
|
||||
{-
|
||||
②c
|
||||
}
|
||||
[③①]
|
||||
R.L.
|
||||
[①E]
|
||||
Ψ
|
||||
-[*]-
|
||||
↑
|
||||
.日
|
||||
[②d]
|
||||
[②
|
||||
[②⑦]
|
||||
[②②]
|
||||
[③e]
|
||||
[①i]
|
||||
[①B]
|
||||
[①h]
|
||||
[①d]
|
||||
[①g]
|
||||
[①②]
|
||||
[②a]
|
||||
f]
|
||||
[⑩]
|
||||
a]
|
||||
[①e]
|
||||
[②h]
|
||||
[②⑥]
|
||||
[③d]
|
||||
[②⑩]
|
||||
e]
|
||||
〉
|
||||
】
|
||||
元/吨
|
||||
[②⑩]
|
||||
2.3%
|
||||
5:0
|
||||
[①]
|
||||
::
|
||||
[②]
|
||||
[③]
|
||||
[④]
|
||||
[⑤]
|
||||
[⑥]
|
||||
[⑦]
|
||||
[⑧]
|
||||
[⑨]
|
||||
……
|
||||
——
|
||||
?
|
||||
、
|
||||
。
|
||||
“
|
||||
”
|
||||
《
|
||||
》
|
||||
!
|
||||
,
|
||||
:
|
||||
;
|
||||
?
|
||||
.
|
||||
,
|
||||
.
|
||||
'
|
||||
?
|
||||
·
|
||||
———
|
||||
──
|
||||
?
|
||||
—
|
||||
<
|
||||
>
|
||||
(
|
||||
)
|
||||
〔
|
||||
〕
|
||||
[
|
||||
]
|
||||
(
|
||||
)
|
||||
-
|
||||
+
|
||||
~
|
||||
×
|
||||
/
|
||||
/
|
||||
①
|
||||
②
|
||||
③
|
||||
④
|
||||
⑤
|
||||
⑥
|
||||
⑦
|
||||
⑧
|
||||
⑨
|
||||
⑩
|
||||
Ⅲ
|
||||
В
|
||||
"
|
||||
;
|
||||
#
|
||||
@
|
||||
γ
|
||||
μ
|
||||
φ
|
||||
φ.
|
||||
×
|
||||
Δ
|
||||
■
|
||||
▲
|
||||
sub
|
||||
exp
|
||||
sup
|
||||
sub
|
||||
Lex
|
||||
#
|
||||
%
|
||||
&
|
||||
'
|
||||
+
|
||||
+ξ
|
||||
++
|
||||
-
|
||||
-β
|
||||
<
|
||||
<±
|
||||
<Δ
|
||||
<λ
|
||||
<φ
|
||||
<<
|
||||
=
|
||||
=
|
||||
=☆
|
||||
=-
|
||||
>
|
||||
>λ
|
||||
_
|
||||
~±
|
||||
~+
|
||||
[⑤f]
|
||||
[⑤d]
|
||||
[②i]
|
||||
≈
|
||||
[②G]
|
||||
[①f]
|
||||
LI
|
||||
㈧
|
||||
[-
|
||||
......
|
||||
〉
|
||||
[③⑩]
|
||||
第二
|
||||
一番
|
||||
一直
|
||||
一个
|
||||
一些
|
||||
许多
|
||||
种
|
||||
有的是
|
||||
也就是说
|
||||
末##末
|
||||
啊
|
||||
阿
|
||||
哎
|
||||
哎呀
|
||||
哎哟
|
||||
唉
|
||||
俺
|
||||
俺们
|
||||
按
|
||||
按照
|
||||
吧
|
||||
吧哒
|
||||
把
|
||||
罢了
|
||||
被
|
||||
本
|
||||
本着
|
||||
比
|
||||
比方
|
||||
比如
|
||||
鄙人
|
||||
彼
|
||||
彼此
|
||||
边
|
||||
别
|
||||
别的
|
||||
别说
|
||||
并
|
||||
并且
|
||||
不比
|
||||
不成
|
||||
不单
|
||||
不但
|
||||
不独
|
||||
不管
|
||||
不光
|
||||
不过
|
||||
不仅
|
||||
不拘
|
||||
不论
|
||||
不怕
|
||||
不然
|
||||
不如
|
||||
不特
|
||||
不惟
|
||||
不问
|
||||
不只
|
||||
朝
|
||||
朝着
|
||||
趁
|
||||
趁着
|
||||
乘
|
||||
冲
|
||||
除
|
||||
除此之外
|
||||
除非
|
||||
除了
|
||||
此
|
||||
此间
|
||||
此外
|
||||
从
|
||||
从而
|
||||
打
|
||||
待
|
||||
但
|
||||
但是
|
||||
当
|
||||
当着
|
||||
到
|
||||
得
|
||||
的
|
||||
的话
|
||||
等
|
||||
等等
|
||||
地
|
||||
第
|
||||
叮咚
|
||||
对
|
||||
对于
|
||||
多
|
||||
多少
|
||||
而
|
||||
而况
|
||||
而且
|
||||
而是
|
||||
而外
|
||||
而言
|
||||
而已
|
||||
尔后
|
||||
反过来
|
||||
反过来说
|
||||
反之
|
||||
非但
|
||||
非徒
|
||||
否则
|
||||
嘎
|
||||
嘎登
|
||||
该
|
||||
赶
|
||||
个
|
||||
各
|
||||
各个
|
||||
各位
|
||||
各种
|
||||
各自
|
||||
给
|
||||
根据
|
||||
跟
|
||||
故
|
||||
故此
|
||||
固然
|
||||
关于
|
||||
管
|
||||
归
|
||||
果然
|
||||
果真
|
||||
过
|
||||
哈
|
||||
哈哈
|
||||
呵
|
||||
和
|
||||
何
|
||||
何处
|
||||
何况
|
||||
何时
|
||||
嘿
|
||||
哼
|
||||
哼唷
|
||||
呼哧
|
||||
乎
|
||||
哗
|
||||
还是
|
||||
还有
|
||||
换句话说
|
||||
换言之
|
||||
或
|
||||
或是
|
||||
或者
|
||||
极了
|
||||
及
|
||||
及其
|
||||
及至
|
||||
即
|
||||
即便
|
||||
即或
|
||||
即令
|
||||
即若
|
||||
即使
|
||||
几
|
||||
几时
|
||||
己
|
||||
既
|
||||
既然
|
||||
既是
|
||||
继而
|
||||
加之
|
||||
假如
|
||||
假若
|
||||
假使
|
||||
鉴于
|
||||
将
|
||||
较
|
||||
较之
|
||||
叫
|
||||
接着
|
||||
结果
|
||||
借
|
||||
紧接着
|
||||
进而
|
||||
尽
|
||||
尽管
|
||||
经
|
||||
经过
|
||||
就
|
||||
就是
|
||||
就是说
|
||||
据
|
||||
具体地说
|
||||
具体说来
|
||||
开始
|
||||
开外
|
||||
靠
|
||||
咳
|
||||
可
|
||||
可见
|
||||
可是
|
||||
可以
|
||||
况且
|
||||
啦
|
||||
来
|
||||
来着
|
||||
离
|
||||
例如
|
||||
哩
|
||||
连
|
||||
连同
|
||||
两者
|
||||
了
|
||||
临
|
||||
另
|
||||
另外
|
||||
另一方面
|
||||
论
|
||||
嘛
|
||||
吗
|
||||
慢说
|
||||
漫说
|
||||
冒
|
||||
么
|
||||
每
|
||||
每当
|
||||
们
|
||||
莫若
|
||||
某
|
||||
某个
|
||||
某些
|
||||
拿
|
||||
哪
|
||||
哪边
|
||||
哪儿
|
||||
哪个
|
||||
哪里
|
||||
哪年
|
||||
哪怕
|
||||
哪天
|
||||
哪些
|
||||
哪样
|
||||
那
|
||||
那边
|
||||
那儿
|
||||
那个
|
||||
那会儿
|
||||
那里
|
||||
那么
|
||||
那么些
|
||||
那么样
|
||||
那时
|
||||
那些
|
||||
那样
|
||||
乃
|
||||
乃至
|
||||
呢
|
||||
能
|
||||
你
|
||||
你们
|
||||
您
|
||||
宁
|
||||
宁可
|
||||
宁肯
|
||||
宁愿
|
||||
哦
|
||||
呕
|
||||
啪达
|
||||
旁人
|
||||
呸
|
||||
凭
|
||||
凭借
|
||||
其
|
||||
其次
|
||||
其二
|
||||
其他
|
||||
其它
|
||||
其一
|
||||
其余
|
||||
其中
|
||||
起
|
||||
起见
|
||||
起见
|
||||
岂但
|
||||
恰恰相反
|
||||
前后
|
||||
前者
|
||||
且
|
||||
然而
|
||||
然后
|
||||
然则
|
||||
让
|
||||
人家
|
||||
任
|
||||
任何
|
||||
任凭
|
||||
如
|
||||
如此
|
||||
如果
|
||||
如何
|
||||
如其
|
||||
如若
|
||||
如上所述
|
||||
若
|
||||
若非
|
||||
若是
|
||||
啥
|
||||
上下
|
||||
尚且
|
||||
设若
|
||||
设使
|
||||
甚而
|
||||
甚么
|
||||
甚至
|
||||
省得
|
||||
时候
|
||||
什么
|
||||
什么样
|
||||
使得
|
||||
是
|
||||
是的
|
||||
首先
|
||||
谁
|
||||
谁知
|
||||
顺
|
||||
顺着
|
||||
似的
|
||||
虽
|
||||
虽然
|
||||
虽说
|
||||
虽则
|
||||
随
|
||||
随着
|
||||
所
|
||||
所以
|
||||
他
|
||||
他们
|
||||
他人
|
||||
它
|
||||
它们
|
||||
她
|
||||
她们
|
||||
倘
|
||||
倘或
|
||||
倘然
|
||||
倘若
|
||||
倘使
|
||||
腾
|
||||
替
|
||||
通过
|
||||
同
|
||||
同时
|
||||
哇
|
||||
万一
|
||||
往
|
||||
望
|
||||
为
|
||||
为何
|
||||
为了
|
||||
为什么
|
||||
为着
|
||||
喂
|
||||
嗡嗡
|
||||
我
|
||||
我们
|
||||
呜
|
||||
呜呼
|
||||
乌乎
|
||||
无论
|
||||
无宁
|
||||
毋宁
|
||||
嘻
|
||||
吓
|
||||
相对而言
|
||||
像
|
||||
向
|
||||
向着
|
||||
嘘
|
||||
呀
|
||||
焉
|
||||
沿
|
||||
沿着
|
||||
要
|
||||
要不
|
||||
要不然
|
||||
要不是
|
||||
要么
|
||||
要是
|
||||
也
|
||||
也罢
|
||||
也好
|
||||
一
|
||||
一般
|
||||
一旦
|
||||
一方面
|
||||
一来
|
||||
一切
|
||||
一样
|
||||
一则
|
||||
依
|
||||
依照
|
||||
矣
|
||||
以
|
||||
以便
|
||||
以及
|
||||
以免
|
||||
以至
|
||||
以至于
|
||||
以致
|
||||
抑或
|
||||
因
|
||||
因此
|
||||
因而
|
||||
因为
|
||||
哟
|
||||
用
|
||||
由
|
||||
由此可见
|
||||
由于
|
||||
有
|
||||
有的
|
||||
有关
|
||||
有些
|
||||
又
|
||||
于
|
||||
于是
|
||||
于是乎
|
||||
与
|
||||
与此同时
|
||||
与否
|
||||
与其
|
||||
越是
|
||||
云云
|
||||
哉
|
||||
再说
|
||||
再者
|
||||
在
|
||||
在下
|
||||
咱
|
||||
咱们
|
||||
则
|
||||
怎
|
||||
怎么
|
||||
怎么办
|
||||
怎么样
|
||||
怎样
|
||||
咋
|
||||
照
|
||||
照着
|
||||
者
|
||||
这
|
||||
这边
|
||||
这儿
|
||||
这个
|
||||
这会儿
|
||||
这就是说
|
||||
这里
|
||||
这么
|
||||
这么点儿
|
||||
这么些
|
||||
这么样
|
||||
这时
|
||||
这些
|
||||
这样
|
||||
正如
|
||||
吱
|
||||
之
|
||||
之类
|
||||
之所以
|
||||
之一
|
||||
只是
|
||||
只限
|
||||
只要
|
||||
只有
|
||||
至
|
||||
至于
|
||||
诸位
|
||||
着
|
||||
着呢
|
||||
自
|
||||
自从
|
||||
自个儿
|
||||
自各儿
|
||||
自己
|
||||
自家
|
||||
自身
|
||||
综上所述
|
||||
总的来看
|
||||
总的来说
|
||||
总的说来
|
||||
总而言之
|
||||
总之
|
||||
纵
|
||||
纵令
|
||||
纵然
|
||||
纵使
|
||||
遵照
|
||||
作为
|
||||
兮
|
||||
呃
|
||||
呗
|
||||
咚
|
||||
咦
|
||||
喏
|
||||
啐
|
||||
喔唷
|
||||
嗬
|
||||
嗯
|
||||
嗳
|
||||
273
astrbot/core/knowledge_base/retrieval/manager.py
Normal file
273
astrbot/core/knowledge_base/retrieval/manager.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""检索管理器
|
||||
|
||||
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
from ..kb_helper import KBHelper
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""检索结果"""
|
||||
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
kb_id: str
|
||||
kb_name: str
|
||||
content: str
|
||||
score: float
|
||||
metadata: dict
|
||||
|
||||
|
||||
class RetrievalManager:
|
||||
"""检索管理器
|
||||
|
||||
职责:
|
||||
- 协调稠密检索、稀疏检索和 Rerank
|
||||
- 结果融合和排序
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sparse_retriever: SparseRetriever,
|
||||
rank_fusion: RankFusion,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
):
|
||||
"""初始化检索管理器
|
||||
|
||||
Args:
|
||||
vec_db_factory: 向量数据库工厂
|
||||
sparse_retriever: 稀疏检索器
|
||||
rank_fusion: 结果融合器
|
||||
kb_db: 知识库数据库实例
|
||||
"""
|
||||
self.sparse_retriever = sparse_retriever
|
||||
self.rank_fusion = rank_fusion
|
||||
self.kb_db = kb_db
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
kb_id_helper_map: dict[str, KBHelper],
|
||||
top_k_fusion: int = 20,
|
||||
top_m_final: int = 5,
|
||||
) -> List[RetrievalResult]:
|
||||
"""混合检索
|
||||
|
||||
流程:
|
||||
1. 稠密检索 (向量相似度)
|
||||
2. 稀疏检索 (BM25)
|
||||
3. 结果融合 (RRF)
|
||||
4. Rerank 重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_m_final: 最终返回数量
|
||||
enable_rerank: 是否启用 Rerank
|
||||
|
||||
Returns:
|
||||
List[RetrievalResult]: 检索结果列表
|
||||
"""
|
||||
if not kb_ids:
|
||||
return []
|
||||
|
||||
kb_options: dict = {}
|
||||
new_kb_ids = []
|
||||
for kb_id in kb_ids:
|
||||
kb_helper = kb_id_helper_map.get(kb_id)
|
||||
if kb_helper:
|
||||
kb = kb_helper.kb
|
||||
kb_options[kb_id] = {
|
||||
"top_k_dense": kb.top_k_dense or 50,
|
||||
"top_k_sparse": kb.top_k_sparse or 50,
|
||||
"top_m_final": kb.top_m_final or 5,
|
||||
"vec_db": kb_helper.vec_db,
|
||||
"rerank_provider_id": kb.rerank_provider_id,
|
||||
}
|
||||
new_kb_ids.append(kb_id)
|
||||
else:
|
||||
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
|
||||
|
||||
kb_ids = new_kb_ids
|
||||
|
||||
# 1. 稠密检索
|
||||
time_start = time.time()
|
||||
dense_results = await self._dense_retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
kb_options=kb_options,
|
||||
)
|
||||
time_end = time.time()
|
||||
logger.debug(
|
||||
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
|
||||
)
|
||||
|
||||
# 2. 稀疏检索
|
||||
time_start = time.time()
|
||||
sparse_results = await self.sparse_retriever.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
kb_options=kb_options,
|
||||
)
|
||||
time_end = time.time()
|
||||
logger.debug(
|
||||
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
|
||||
)
|
||||
|
||||
# 3. 结果融合
|
||||
time_start = time.time()
|
||||
fused_results = await self.rank_fusion.fuse(
|
||||
dense_results=dense_results,
|
||||
sparse_results=sparse_results,
|
||||
top_k=top_k_fusion,
|
||||
)
|
||||
time_end = time.time()
|
||||
logger.debug(
|
||||
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
|
||||
)
|
||||
|
||||
# 4. 转换为 RetrievalResult (获取元数据)
|
||||
retrieval_results = []
|
||||
for fr in fused_results:
|
||||
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
|
||||
if metadata_dict:
|
||||
retrieval_results.append(
|
||||
RetrievalResult(
|
||||
chunk_id=fr.chunk_id,
|
||||
doc_id=fr.doc_id,
|
||||
doc_name=metadata_dict["document"].doc_name,
|
||||
kb_id=fr.kb_id,
|
||||
kb_name=metadata_dict["knowledge_base"].kb_name,
|
||||
content=fr.content,
|
||||
score=fr.score,
|
||||
metadata={
|
||||
"chunk_index": fr.chunk_index,
|
||||
"char_count": len(fr.content),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
and vec_db.rerank_provider
|
||||
and rerank_pi
|
||||
and rerank_pi == vec_db.rerank_provider.meta().id
|
||||
):
|
||||
first_rerank = vec_db.rerank_provider
|
||||
break
|
||||
if first_rerank and retrieval_results:
|
||||
retrieval_results = await self._rerank(
|
||||
query=query,
|
||||
results=retrieval_results,
|
||||
top_k=top_m_final,
|
||||
rerank_provider=first_rerank,
|
||||
)
|
||||
|
||||
return retrieval_results[:top_m_final]
|
||||
|
||||
async def _dense_retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
kb_options: dict,
|
||||
):
|
||||
"""稠密检索 (向量相似度)
|
||||
|
||||
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[Result]: 检索结果列表
|
||||
"""
|
||||
all_results: list[Result] = []
|
||||
for kb_id in kb_ids:
|
||||
if kb_id not in kb_options:
|
||||
continue
|
||||
try:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
dense_k = int(kb_options[kb_id]["top_k_dense"])
|
||||
vec_results = await vec_db.retrieve(
|
||||
query=query,
|
||||
k=dense_k,
|
||||
fetch_k=dense_k * 2,
|
||||
rerank=False, # 稠密检索阶段不进行 rerank
|
||||
metadata_filters={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
all_results.extend(vec_results)
|
||||
except Exception as e:
|
||||
from astrbot.core import logger
|
||||
|
||||
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
|
||||
continue
|
||||
|
||||
# 按相似度排序并返回 top_k
|
||||
all_results.sort(key=lambda x: x.similarity, reverse=True)
|
||||
# return all_results[: len(all_results) // len(kb_ids)]
|
||||
return all_results
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
query: str,
|
||||
results: List[RetrievalResult],
|
||||
top_k: int,
|
||||
rerank_provider: RerankProvider,
|
||||
) -> List[RetrievalResult]:
|
||||
"""Rerank 重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
results: 检索结果列表
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[RetrievalResult]: 重排序后的结果列表
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# 准备文档列表
|
||||
docs = [r.content for r in results]
|
||||
|
||||
# 调用 Rerank Provider
|
||||
rerank_results = await rerank_provider.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
)
|
||||
|
||||
# 更新分数并重新排序
|
||||
reranked_list = []
|
||||
for rerank_result in rerank_results:
|
||||
idx = rerank_result.index
|
||||
if idx < len(results):
|
||||
result = results[idx]
|
||||
result.score = rerank_result.relevance_score
|
||||
reranked_list.append(result)
|
||||
|
||||
reranked_list.sort(key=lambda x: x.score, reverse=True)
|
||||
|
||||
return reranked_list[:top_k]
|
||||
138
astrbot/core/knowledge_base/retrieval/rank_fusion.py
Normal file
138
astrbot/core/knowledge_base/retrieval/rank_fusion.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""检索结果融合器
|
||||
|
||||
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedResult:
|
||||
"""融合后的检索结果"""
|
||||
|
||||
chunk_id: str
|
||||
chunk_index: int
|
||||
doc_id: str
|
||||
kb_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
|
||||
class RankFusion:
|
||||
"""检索结果融合器
|
||||
|
||||
职责:
|
||||
- 融合稠密检索和稀疏检索的结果
|
||||
- 使用 Reciprocal Rank Fusion (RRF) 算法
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
|
||||
"""初始化结果融合器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
k: RRF 参数,用于平滑排名
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self.k = k
|
||||
|
||||
async def fuse(
|
||||
self,
|
||||
dense_results: list[Result],
|
||||
sparse_results: list[SparseResult],
|
||||
top_k: int = 20,
|
||||
) -> list[FusedResult]:
|
||||
"""融合稠密和稀疏检索结果
|
||||
|
||||
RRF 公式:
|
||||
score(doc) = sum(1 / (k + rank_i))
|
||||
|
||||
Args:
|
||||
dense_results: 稠密检索结果
|
||||
sparse_results: 稀疏检索结果
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
List[FusedResult]: 融合后的结果列表
|
||||
"""
|
||||
# 1. 构建排名映射
|
||||
dense_ranks = {
|
||||
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
|
||||
} # 这里的 doc_id 实际上是 chunk_id
|
||||
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
|
||||
|
||||
# 2. 收集所有唯一的 ID
|
||||
# 需要统一为 chunk_id
|
||||
all_chunk_ids = set()
|
||||
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
|
||||
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
|
||||
|
||||
# 处理稀疏检索结果
|
||||
for r in sparse_results:
|
||||
all_chunk_ids.add(r.chunk_id)
|
||||
chunk_id_to_sparse[r.chunk_id] = r
|
||||
|
||||
# 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
|
||||
for r in dense_results:
|
||||
vec_doc_id = r.data["doc_id"]
|
||||
all_chunk_ids.add(vec_doc_id)
|
||||
vec_doc_id_to_dense[vec_doc_id] = r
|
||||
|
||||
# 3. 计算 RRF 分数
|
||||
rrf_scores: dict[str, float] = {}
|
||||
|
||||
for identifier in all_chunk_ids:
|
||||
score = 0.0
|
||||
|
||||
# 来自稠密检索的贡献
|
||||
if identifier in dense_ranks:
|
||||
score += 1.0 / (self.k + dense_ranks[identifier])
|
||||
|
||||
# 来自稀疏检索的贡献
|
||||
if identifier in sparse_ranks:
|
||||
score += 1.0 / (self.k + sparse_ranks[identifier])
|
||||
|
||||
rrf_scores[identifier] = score
|
||||
|
||||
# 4. 排序
|
||||
sorted_ids = sorted(
|
||||
rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
|
||||
)[:top_k]
|
||||
|
||||
# 5. 构建融合结果
|
||||
fused_results = []
|
||||
for identifier in sorted_ids:
|
||||
# 优先从稀疏检索获取完整信息
|
||||
if identifier in chunk_id_to_sparse:
|
||||
sr = chunk_id_to_sparse[identifier]
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=sr.chunk_id,
|
||||
chunk_index=sr.chunk_index,
|
||||
doc_id=sr.doc_id,
|
||||
kb_id=sr.kb_id,
|
||||
content=sr.content,
|
||||
score=rrf_scores[identifier],
|
||||
)
|
||||
)
|
||||
elif identifier in vec_doc_id_to_dense:
|
||||
# 从向量检索获取信息,需要从数据库获取块的详细信息
|
||||
vec_result = vec_doc_id_to_dense[identifier]
|
||||
chunk_md = json.loads(vec_result.data["metadata"])
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=identifier,
|
||||
chunk_index=chunk_md["chunk_index"],
|
||||
doc_id=chunk_md["kb_doc_id"],
|
||||
kb_id=chunk_md["kb_id"],
|
||||
content=vec_result.data["text"],
|
||||
score=rrf_scores[identifier],
|
||||
)
|
||||
)
|
||||
|
||||
return fused_results
|
||||
130
astrbot/core/knowledge_base/retrieval/sparse_retriever.py
Normal file
130
astrbot/core/knowledge_base/retrieval/sparse_retriever.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""稀疏检索器
|
||||
|
||||
使用 BM25 算法进行基于关键词的文档检索
|
||||
"""
|
||||
|
||||
import jieba
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from rank_bm25 import BM25Okapi
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseResult:
|
||||
"""稀疏检索结果"""
|
||||
|
||||
chunk_index: int
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
kb_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
|
||||
class SparseRetriever:
|
||||
"""BM25 稀疏检索器
|
||||
|
||||
职责:
|
||||
- 基于关键词的文档检索
|
||||
- 使用 BM25 算法计算相关度
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
"""初始化稀疏检索器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self._index_cache = {} # 缓存 BM25 索引
|
||||
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
self.hit_stopwords = {
|
||||
word.strip() for word in set(f.read().splitlines()) if word.strip()
|
||||
}
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: list[str],
|
||||
kb_options: dict,
|
||||
) -> list[SparseResult]:
|
||||
"""执行稀疏检索
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
kb_options: 每个知识库的检索选项
|
||||
|
||||
Returns:
|
||||
List[SparseResult]: 检索结果列表
|
||||
"""
|
||||
# 1. 获取所有相关块
|
||||
top_k_sparse = 0
|
||||
chunks = []
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
|
||||
if not vec_db:
|
||||
continue
|
||||
result = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={}, limit=None, offset=None
|
||||
)
|
||||
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
|
||||
result = [
|
||||
{
|
||||
"chunk_id": doc["doc_id"],
|
||||
"chunk_index": chunk_md["chunk_index"],
|
||||
"doc_id": chunk_md["kb_doc_id"],
|
||||
"kb_id": kb_id,
|
||||
"text": doc["text"],
|
||||
}
|
||||
for doc, chunk_md in zip(result, chunk_mds)
|
||||
]
|
||||
chunks.extend(result)
|
||||
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
|
||||
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
# 2. 准备文档和索引
|
||||
corpus = [chunk["text"] for chunk in chunks]
|
||||
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
|
||||
tokenized_corpus = [
|
||||
[word for word in doc if word not in self.hit_stopwords]
|
||||
for doc in tokenized_corpus
|
||||
]
|
||||
|
||||
# 3. 构建 BM25 索引
|
||||
bm25 = BM25Okapi(tokenized_corpus)
|
||||
|
||||
# 4. 执行检索
|
||||
tokenized_query = list(jieba.cut(query))
|
||||
tokenized_query = [
|
||||
word for word in tokenized_query if word not in self.hit_stopwords
|
||||
]
|
||||
scores = bm25.get_scores(tokenized_query)
|
||||
|
||||
# 5. 排序并返回 Top-K
|
||||
results = []
|
||||
for idx, score in enumerate(scores):
|
||||
chunk = chunks[idx]
|
||||
results.append(
|
||||
SparseResult(
|
||||
chunk_id=chunk["chunk_id"],
|
||||
chunk_index=chunk["chunk_index"],
|
||||
doc_id=chunk["doc_id"],
|
||||
kb_id=chunk["kb_id"],
|
||||
content=chunk["text"],
|
||||
score=float(score),
|
||||
)
|
||||
)
|
||||
|
||||
results.sort(key=lambda x: x.score, reverse=True)
|
||||
# return results[: len(results) // len(kb_ids)]
|
||||
return results[:top_k_sparse]
|
||||
@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self.strategy_selector = StrategySelector(config)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, check_text: str = None
|
||||
self, event: AstrMessageEvent, check_text: str | None = None
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
|
||||
@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
||||
self.secret_key = sk
|
||||
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
||||
|
||||
def check(self, content: str):
|
||||
def check(self, content: str) -> tuple[bool, str]:
|
||||
res = self.client.textCensorUserDefined(content)
|
||||
if "conclusionType" not in res:
|
||||
return False, ""
|
||||
|
||||
@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
|
||||
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||
# )
|
||||
|
||||
def check(self, content: str) -> bool:
|
||||
def check(self, content: str) -> tuple[bool, str]:
|
||||
for keyword in self.keywords:
|
||||
if re.search(keyword, content):
|
||||
return False, "内容安全检查不通过,匹配到敏感词。"
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -36,6 +36,9 @@ async def call_handler(
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
@@ -94,5 +97,6 @@ async def call_event_hook(
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return True
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
@@ -22,6 +23,26 @@ class PreProcessStage(Stage):
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""在处理事件之前的预处理"""
|
||||
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
|
||||
supported = {"telegram", "lark"}
|
||||
platform = event.get_platform_name()
|
||||
cfg = (
|
||||
self.config.get("platform_specific", {})
|
||||
.get(platform, {})
|
||||
.get("pre_ack_emoji", {})
|
||||
) or {}
|
||||
emojis = cfg.get("emojis") or []
|
||||
if (
|
||||
cfg.get("enable", False)
|
||||
and platform in supported
|
||||
and emojis
|
||||
and event.is_at_or_wake_command
|
||||
):
|
||||
try:
|
||||
await event.react(random.choice(emojis))
|
||||
except Exception as e:
|
||||
logger.warning(f"{platform} 预回应表情发送失败: {e}")
|
||||
|
||||
# 路径映射
|
||||
if mappings := self.platform_settings.get("path_mapping", []):
|
||||
# 支持 Record,Image 消息段的路径映射。
|
||||
@@ -46,6 +67,9 @@ class PreProcessStage(Stage):
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
|
||||
)
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
|
||||
@@ -6,7 +6,9 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
from datetime import timedelta
|
||||
from collections.abc import AsyncGenerator
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
@@ -31,6 +33,7 @@ from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext, call_event_hook, call_handler
|
||||
from ..stage import Stage
|
||||
from ..utils import inject_kb_context
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
@@ -42,7 +45,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@@ -100,7 +103,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description,
|
||||
system_prompt=tool.description or "",
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
@@ -133,6 +136,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
|
||||
if not llm_response:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||
)
|
||||
@@ -148,7 +160,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
yield mcp.types.TextContent(
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
@@ -175,21 +187,33 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
handler=awaitable,
|
||||
**tool_args,
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
# async for resp in wrapper:
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
anext(wrapper),
|
||||
timeout=run_context.context.tool_call_timeout,
|
||||
)
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
@@ -200,16 +224,23 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
):
|
||||
if not tool.mcp_client:
|
||||
raise ValueError("MCP client is not available for MCP function tools.")
|
||||
res = await tool.mcp_client.session.call_tool(
|
||||
|
||||
session = tool.mcp_client.session
|
||||
if not session:
|
||||
raise ValueError("MCP session is not available for MCP function tools.")
|
||||
res = await session.call_tool(
|
||||
name=tool.name,
|
||||
arguments=tool_args,
|
||||
read_timeout_seconds=timedelta(
|
||||
seconds=run_context.context.tool_call_timeout
|
||||
),
|
||||
)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
@@ -271,19 +302,12 @@ async def run_agent(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
astr_event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -300,6 +324,7 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.max_step: int = settings.get("max_agent_step", 30)
|
||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
@@ -313,7 +338,7 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
"""选择使用的 LLM 提供商"""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
_ctx = self.ctx.plugin_manager.context
|
||||
@@ -325,7 +350,7 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent):
|
||||
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
@@ -337,11 +362,13 @@ class LLMRequestSubStage(Stage):
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, _nested: bool = False
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
@@ -356,6 +383,9 @@ class LLMRequestSubStage(Stage):
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
@@ -390,6 +420,14 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 应用知识库
|
||||
try:
|
||||
await inject_kb_context(
|
||||
umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用知识库时遇到问题: {e}")
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
@@ -444,13 +482,19 @@ class LLMRequestSubStage(Stage):
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
plugin = star_map.get(tool.handler_module_path)
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
plugin = star_map.get(mp)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
@@ -461,6 +505,7 @@ class LLMRequestSubStage(Stage):
|
||||
first_provider_request=req,
|
||||
curr_provider_request=req,
|
||||
streaming=self.streaming_response,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
@@ -487,8 +532,10 @@ class LLMRequestSubStage(Stage):
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
else:
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
@@ -499,16 +546,29 @@ class LLMRequestSubStage(Stage):
|
||||
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_webchat(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
if not req.conversation:
|
||||
return
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
@@ -517,7 +577,23 @@ class LLMRequestSubStage(Stage):
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
content = latest_pair[0].get("content", "")
|
||||
if isinstance(content, list):
|
||||
# 多模态
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
text_parts.append("[图片]")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
cleaned_text = "User: " + " ".join(text_parts).strip()
|
||||
elif isinstance(content, str):
|
||||
cleaned_text = "User: " + content.strip()
|
||||
else:
|
||||
return
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
|
||||
@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
|
||||
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
continue
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
md = star_map.get(handler.handler_module_path)
|
||||
if not md:
|
||||
logger.warning(
|
||||
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
|
||||
)
|
||||
continue
|
||||
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
|
||||
try:
|
||||
wrapper = call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
80
astrbot/core/pipeline/process_stage/utils.py
Normal file
80
astrbot/core/pipeline/process_stage/utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.api import logger, sp
|
||||
|
||||
|
||||
async def inject_kb_context(
|
||||
umo: str,
|
||||
p_ctx: PipelineContext,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
"""inject knowledge base context into the provider request
|
||||
|
||||
Args:
|
||||
umo: Unique message object (session ID)
|
||||
p_ctx: Pipeline context
|
||||
req: Provider request
|
||||
"""
|
||||
|
||||
kb_mgr = p_ctx.plugin_manager.context.kb_manager
|
||||
|
||||
# 1. 优先读取会话级配置
|
||||
session_config = await sp.session_get(umo, "kb_config", default={})
|
||||
|
||||
if session_config and "kb_ids" in session_config:
|
||||
# 会话级配置
|
||||
kb_ids = session_config.get("kb_ids", [])
|
||||
|
||||
# 如果配置为空列表,明确表示不使用知识库
|
||||
if not kb_ids:
|
||||
logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
|
||||
return
|
||||
|
||||
top_k = session_config.get("top_k", 5)
|
||||
|
||||
# 将 kb_ids 转换为 kb_names
|
||||
kb_names = []
|
||||
invalid_kb_ids = []
|
||||
for kb_id in kb_ids:
|
||||
kb_helper = await kb_mgr.get_kb(kb_id)
|
||||
if kb_helper:
|
||||
kb_names.append(kb_helper.kb.kb_name)
|
||||
else:
|
||||
logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
|
||||
invalid_kb_ids.append(kb_id)
|
||||
|
||||
if invalid_kb_ids:
|
||||
logger.warning(
|
||||
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
|
||||
)
|
||||
|
||||
if not kb_names:
|
||||
return
|
||||
|
||||
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
|
||||
else:
|
||||
kb_names = p_ctx.astrbot_config.get("kb_names", [])
|
||||
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
|
||||
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
|
||||
|
||||
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
|
||||
|
||||
if not kb_names:
|
||||
return
|
||||
|
||||
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
|
||||
kb_context = await kb_mgr.retrieve(
|
||||
query=req.prompt,
|
||||
kb_names=kb_names,
|
||||
top_k_fusion=top_k_fusion,
|
||||
top_m_final=top_k,
|
||||
)
|
||||
|
||||
if not kb_context:
|
||||
return
|
||||
|
||||
formatted = kb_context.get("context_text", "")
|
||||
if formatted:
|
||||
results = kb_context.get("results", [])
|
||||
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
|
||||
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
|
||||
@@ -1,17 +1,15 @@
|
||||
import random
|
||||
import asyncio
|
||||
import math
|
||||
import traceback
|
||||
import astrbot.core.message.components as Comp
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from ..context import PipelineContext, call_event_hook
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.components import BaseMessageComponent, ComponentType
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -114,6 +112,43 @@ class RespondStage(Stage):
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
|
||||
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
|
||||
"""检查是否需要分段回复"""
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
"qq_official",
|
||||
"weixin_official_account",
|
||||
"dingtalk",
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _extract_comp(
|
||||
self,
|
||||
raw_chain: list[BaseMessageComponent],
|
||||
extract_types: set[ComponentType],
|
||||
modify_raw_chain: bool = True,
|
||||
):
|
||||
extracted = []
|
||||
if modify_raw_chain:
|
||||
remaining = []
|
||||
for comp in raw_chain:
|
||||
if comp.type in extract_types:
|
||||
extracted.append(comp)
|
||||
else:
|
||||
remaining.append(comp)
|
||||
raw_chain[:] = remaining
|
||||
else:
|
||||
extracted = [comp for comp in raw_chain if comp.type in extract_types]
|
||||
|
||||
return extracted
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -123,7 +158,14 @@ class RespondStage(Stage):
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
)
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
if result.async_stream is None:
|
||||
logger.warning("async_stream 为空,跳过发送。")
|
||||
return
|
||||
# 流式结果直接交付平台适配器处理
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
@@ -148,87 +190,81 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||
non_record_comps = [
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
# 将 Plain 为空的消息段移除
|
||||
result.chain = [
|
||||
comp
|
||||
for comp in result.chain
|
||||
if not (
|
||||
isinstance(comp, Comp.Plain)
|
||||
and (not comp.text or not comp.text.strip())
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
self.enable_seg
|
||||
and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
# 发送消息链
|
||||
# Record 需要强制单独发送
|
||||
need_separately = {ComponentType.Record}
|
||||
if self.is_seg_reply_required(event):
|
||||
header_comps = self._extract_comp(
|
||||
result.chain,
|
||||
{ComponentType.Reply, ComponentType.At},
|
||||
modify_raw_chain=True,
|
||||
)
|
||||
and event.get_platform_name()
|
||||
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
||||
):
|
||||
decorated_comps = []
|
||||
if self.reply_with_mention:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Comp.At):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
if self.reply_with_quote:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Comp.Reply):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
if not result.chain or len(result.chain) == 0:
|
||||
# may fix #2670
|
||||
logger.warning(
|
||||
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
|
||||
)
|
||||
return
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
if comp.type in need_separately:
|
||||
await event.send(MessageChain([comp]))
|
||||
else:
|
||||
await event.send(MessageChain([*header_comps, comp]))
|
||||
header_comps.clear()
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
if all(
|
||||
comp.type in {ComponentType.Reply, ComponentType.At}
|
||||
for comp in result.chain
|
||||
):
|
||||
# may fix #2670
|
||||
logger.warning(
|
||||
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
|
||||
)
|
||||
return
|
||||
sep_comps = self._extract_comp(
|
||||
result.chain,
|
||||
need_separately,
|
||||
modify_raw_chain=True,
|
||||
)
|
||||
for comp in sep_comps:
|
||||
chain = MessageChain([comp])
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
await event.send(chain)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chain = MessageChain(result.chain)
|
||||
if result.chain and len(result.chain) > 0:
|
||||
try:
|
||||
await event.send(chain)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
await event.send(MessageChain(non_record_comps))
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
|
||||
return
|
||||
|
||||
event.clear_result()
|
||||
|
||||
@@ -183,56 +183,60 @@ class ResultDecorateStage(Stage):
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
@@ -275,7 +279,6 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
# 触发转发消息
|
||||
has_forwarded = False
|
||||
if event.get_platform_name() == "aiocqhttp":
|
||||
word_cnt = 0
|
||||
for comp in result.chain:
|
||||
@@ -286,9 +289,9 @@ class ResultDecorateStage(Stage):
|
||||
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
|
||||
)
|
||||
result.chain = [node]
|
||||
has_forwarded = True
|
||||
|
||||
if not has_forwarded:
|
||||
has_plain = any(isinstance(item, Plain) for item in result.chain)
|
||||
if has_plain:
|
||||
# at 回复
|
||||
if (
|
||||
self.reply_with_mention
|
||||
|
||||
@@ -74,7 +74,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if event.get_platform_name() == "webchat":
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
self.ctx = ctx
|
||||
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
conv_id = await self.conv_mgr.get_curr_conversation_id(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if not conv_id:
|
||||
await self.conv_mgr.new_conversation(
|
||||
event.unified_msg_origin, platform_id=event.get_platform_id()
|
||||
)
|
||||
|
||||
event.stop_event()
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra():
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
is_group_cmd_handler = any(
|
||||
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
||||
)
|
||||
if not is_group_cmd_handler:
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra(default={}):
|
||||
handlers_parsed_params[handler.handler_full_name] = (
|
||||
event.get_extra("parsed_params")
|
||||
)
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
from typing import List, Union, Optional, AsyncGenerator, Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
@@ -49,7 +49,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否唤醒(是否通过 WakingStage)"""
|
||||
self.is_at_or_wake_command = False
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self._extras: dict[str, Any] = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
@@ -57,7 +57,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||
self._result: MessageEventResult = None
|
||||
self._result: MessageEventResult | None = None
|
||||
"""消息事件的结果"""
|
||||
|
||||
self._has_send_oper = False
|
||||
@@ -90,8 +90,10 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.message_str
|
||||
|
||||
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
|
||||
def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str:
|
||||
outline = ""
|
||||
if not chain:
|
||||
return outline
|
||||
for i in chain:
|
||||
if isinstance(i, Plain):
|
||||
outline += i.text
|
||||
@@ -173,13 +175,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(self, key=None):
|
||||
def get_extra(self, key: str | None = None, default=None) -> Any:
|
||||
"""
|
||||
获取额外的信息。
|
||||
"""
|
||||
if key is None:
|
||||
return self._extras
|
||||
return self._extras.get(key, None)
|
||||
return self._extras.get(key, default)
|
||||
|
||||
def clear_extra(self):
|
||||
"""
|
||||
@@ -261,6 +263,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
result = MessageEventResult().message(result)
|
||||
# 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表
|
||||
if isinstance(result, MessageEventResult) and result.chain is None:
|
||||
result.chain = []
|
||||
self._result = result
|
||||
|
||||
def stop_event(self):
|
||||
@@ -412,6 +417,16 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def react(self, emoji: str):
|
||||
"""
|
||||
对消息添加表情回应。
|
||||
|
||||
默认实现为发送一条包含该表情的消息。
|
||||
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
|
||||
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
|
||||
"""
|
||||
await self.send(MessageChain([Plain(emoji)]))
|
||||
|
||||
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group_id: str = "" # 群组id,如果为私聊,则为空
|
||||
group: Group # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -64,6 +64,28 @@ class AstrBotMessage:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.timestamp = int(time.time())
|
||||
self.group = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
@property
|
||||
def group_id(self) -> str:
|
||||
"""
|
||||
向后兼容的 group_id 属性
|
||||
群组id,如果为私聊,则为空
|
||||
"""
|
||||
if self.group:
|
||||
return self.group.group_id
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
self.group.group_id = value
|
||||
else:
|
||||
self.group = Group(group_id=value)
|
||||
else:
|
||||
self.group = None
|
||||
|
||||
@@ -82,6 +82,10 @@ class PlatformManager:
|
||||
from .sources.wecom.wecom_adapter import (
|
||||
WecomPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wecom_ai_bot":
|
||||
from .sources.wecom_ai_bot.wecomai_adapter import (
|
||||
WecomAIBotAdapter, # noqa: F401
|
||||
)
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
||||
@@ -90,6 +94,10 @@ class PlatformManager:
|
||||
from .sources.discord.discord_platform_adapter import (
|
||||
DiscordPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "misskey":
|
||||
from .sources.misskey.misskey_adapter import (
|
||||
MisskeyPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "slack":
|
||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||
case "satori":
|
||||
|
||||
@@ -14,3 +14,5 @@ class PlatformMetadata:
|
||||
"""平台的默认配置模板"""
|
||||
adapter_display_name: str = None
|
||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||
logo_path: str = None
|
||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
||||
|
||||
@@ -13,10 +13,12 @@ def register_platform_adapter(
|
||||
desc: str,
|
||||
default_config_tmpl: dict = None,
|
||||
adapter_display_name: str = None,
|
||||
logo_path: str = None,
|
||||
):
|
||||
"""用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
@@ -39,6 +41,7 @@ def register_platform_adapter(
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
str(event.sender["user_id"]), event.sender["nickname"]
|
||||
str(event.sender["user_id"]),
|
||||
event.sender.get("card") or event.sender.get("nickname", "N/A"),
|
||||
)
|
||||
if event["message_type"] == "group":
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -107,6 +107,22 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
request = (
|
||||
CreateMessageReactionRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji).build())
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||
return None
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
|
||||
727
astrbot/core/platform/sources/misskey/misskey_adapter.py
Normal file
727
astrbot/core/platform/sources/misskey/misskey_adapter.py
Normal file
@@ -0,0 +1,727 @@
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Dict, Any, Optional, Awaitable, List
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.platform import (
|
||||
AstrBotMessage,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from .misskey_api import MisskeyAPI
|
||||
import os
|
||||
|
||||
try:
|
||||
import magic # type: ignore
|
||||
except Exception:
|
||||
magic = None
|
||||
|
||||
from .misskey_event import MisskeyPlatformEvent
|
||||
from .misskey_utils import (
|
||||
serialize_message_chain,
|
||||
resolve_message_visibility,
|
||||
is_valid_user_session_id,
|
||||
is_valid_room_session_id,
|
||||
add_at_mention_if_needed,
|
||||
process_files,
|
||||
extract_sender_info,
|
||||
create_base_message,
|
||||
process_at_mention,
|
||||
format_poll,
|
||||
cache_user_info,
|
||||
cache_room_info,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# Constants
|
||||
MAX_FILE_UPLOAD_COUNT = 16
|
||||
DEFAULT_UPLOAD_CONCURRENCY = 3
|
||||
|
||||
|
||||
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||
class MisskeyPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config or {}
|
||||
self.settings = platform_settings or {}
|
||||
self.instance_url = self.config.get("misskey_instance_url", "")
|
||||
self.access_token = self.config.get("misskey_token", "")
|
||||
self.max_message_length = self.config.get("max_message_length", 3000)
|
||||
self.default_visibility = self.config.get(
|
||||
"misskey_default_visibility", "public"
|
||||
)
|
||||
self.local_only = self.config.get("misskey_local_only", False)
|
||||
self.enable_chat = self.config.get("misskey_enable_chat", True)
|
||||
self.enable_file_upload = self.config.get("misskey_enable_file_upload", True)
|
||||
self.upload_folder = self.config.get("misskey_upload_folder")
|
||||
|
||||
# download / security related options (exposed to platform_config)
|
||||
self.allow_insecure_downloads = bool(
|
||||
self.config.get("misskey_allow_insecure_downloads", False)
|
||||
)
|
||||
# parse download timeout and chunk size safely
|
||||
_dt = self.config.get("misskey_download_timeout")
|
||||
try:
|
||||
self.download_timeout = int(_dt) if _dt is not None else 15
|
||||
except Exception:
|
||||
self.download_timeout = 15
|
||||
|
||||
_chunk = self.config.get("misskey_download_chunk_size")
|
||||
try:
|
||||
self.download_chunk_size = int(_chunk) if _chunk is not None else 64 * 1024
|
||||
except Exception:
|
||||
self.download_chunk_size = 64 * 1024
|
||||
# parse max download bytes safely
|
||||
_md_bytes = self.config.get("misskey_max_download_bytes")
|
||||
try:
|
||||
self.max_download_bytes = int(_md_bytes) if _md_bytes is not None else None
|
||||
except Exception:
|
||||
self.max_download_bytes = None
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.api: Optional[MisskeyAPI] = None
|
||||
self._running = False
|
||||
self.client_self_id = ""
|
||||
self._bot_username = ""
|
||||
self._user_cache = {}
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
default_config = {
|
||||
"misskey_instance_url": "",
|
||||
"misskey_token": "",
|
||||
"max_message_length": 3000,
|
||||
"misskey_default_visibility": "public",
|
||||
"misskey_local_only": False,
|
||||
"misskey_enable_chat": True,
|
||||
# download / security options
|
||||
"misskey_allow_insecure_downloads": False,
|
||||
"misskey_download_timeout": 15,
|
||||
"misskey_download_chunk_size": 65536,
|
||||
"misskey_max_download_bytes": None,
|
||||
}
|
||||
default_config.update(self.config)
|
||||
|
||||
return PlatformMetadata(
|
||||
name="misskey",
|
||||
description="Misskey 平台适配器",
|
||||
id=self.config.get("id", "misskey"),
|
||||
default_config_tmpl=default_config,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
if not self.instance_url or not self.access_token:
|
||||
logger.error("[Misskey] 配置不完整,无法启动")
|
||||
return
|
||||
|
||||
self.api = MisskeyAPI(
|
||||
self.instance_url,
|
||||
self.access_token,
|
||||
allow_insecure_downloads=self.allow_insecure_downloads,
|
||||
download_timeout=self.download_timeout,
|
||||
chunk_size=self.download_chunk_size,
|
||||
max_download_bytes=self.max_download_bytes,
|
||||
)
|
||||
self._running = True
|
||||
|
||||
try:
|
||||
user_info = await self.api.get_current_user()
|
||||
self.client_self_id = str(user_info.get("id", ""))
|
||||
self._bot_username = user_info.get("username", "")
|
||||
logger.info(
|
||||
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 获取用户信息失败: {e}")
|
||||
self._running = False
|
||||
return
|
||||
|
||||
await self._start_websocket_connection()
|
||||
|
||||
def _register_event_handlers(self, streaming):
|
||||
"""注册事件处理器"""
|
||||
streaming.add_message_handler("notification", self._handle_notification)
|
||||
streaming.add_message_handler("main:notification", self._handle_notification)
|
||||
|
||||
if self.enable_chat:
|
||||
streaming.add_message_handler("newChatMessage", self._handle_chat_message)
|
||||
streaming.add_message_handler(
|
||||
"messaging:newChatMessage", self._handle_chat_message
|
||||
)
|
||||
streaming.add_message_handler("_debug", self._debug_handler)
|
||||
|
||||
async def _send_text_only_message(
|
||||
self, session_id: str, text: str, session, message_chain
|
||||
):
|
||||
"""发送纯文本消息(无文件上传)"""
|
||||
if not self.api:
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
if session_id and is_valid_user_session_id(session_id):
|
||||
from .misskey_utils import extract_user_id_from_session_id
|
||||
|
||||
user_id = extract_user_id_from_session_id(session_id)
|
||||
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
|
||||
await self.api.send_message(payload)
|
||||
elif session_id and is_valid_room_session_id(session_id):
|
||||
from .misskey_utils import extract_room_id_from_session_id
|
||||
|
||||
room_id = extract_room_id_from_session_id(session_id)
|
||||
payload = {"toRoomId": room_id, "text": text}
|
||||
await self.api.send_room_message(payload)
|
||||
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
def _process_poll_data(
|
||||
self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str]
|
||||
):
|
||||
"""处理投票数据,将其添加到消息中"""
|
||||
try:
|
||||
if not isinstance(message.raw_message, dict):
|
||||
message.raw_message = {}
|
||||
message.raw_message["poll"] = poll
|
||||
setattr(message, "poll", poll)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
poll_text = format_poll(poll)
|
||||
if poll_text:
|
||||
message.message.append(Comp.Plain(poll_text))
|
||||
message_parts.append(poll_text)
|
||||
|
||||
def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]:
|
||||
"""从会话和消息链中提取额外字段"""
|
||||
fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None}
|
||||
|
||||
for comp in message_chain.chain:
|
||||
if hasattr(comp, "cw") and getattr(comp, "cw", None):
|
||||
fields["cw"] = getattr(comp, "cw")
|
||||
break
|
||||
|
||||
if hasattr(session, "extra_data") and isinstance(
|
||||
getattr(session, "extra_data", None), dict
|
||||
):
|
||||
extra_data = getattr(session, "extra_data")
|
||||
fields.update(
|
||||
{
|
||||
"poll": extra_data.get("poll"),
|
||||
"renote_id": extra_data.get("renote_id"),
|
||||
"channel_id": extra_data.get("channel_id"),
|
||||
}
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
async def _start_websocket_connection(self):
|
||||
backoff_delay = 1.0
|
||||
max_backoff = 300.0
|
||||
backoff_multiplier = 1.5
|
||||
connection_attempts = 0
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
connection_attempts += 1
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
break
|
||||
|
||||
streaming = self.api.get_streaming_client()
|
||||
self._register_event_handlers(streaming)
|
||||
|
||||
if await streaming.connect():
|
||||
logger.info(
|
||||
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
|
||||
)
|
||||
connection_attempts = 0
|
||||
await streaming.subscribe_channel("main")
|
||||
if self.enable_chat:
|
||||
await streaming.subscribe_channel("messaging")
|
||||
await streaming.subscribe_channel("messagingIndex")
|
||||
logger.info("[Misskey] 聊天频道已订阅")
|
||||
|
||||
backoff_delay = 1.0
|
||||
await streaming.listen()
|
||||
else:
|
||||
logger.error(
|
||||
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
|
||||
)
|
||||
|
||||
if self._running:
|
||||
jitter = random.uniform(0, 1.0)
|
||||
sleep_time = backoff_delay + jitter
|
||||
logger.info(
|
||||
f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
||||
)
|
||||
await asyncio.sleep(sleep_time)
|
||||
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
||||
|
||||
async def _handle_notification(self, data: Dict[str, Any]):
|
||||
try:
|
||||
notification_type = data.get("type")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到通知事件: type={notification_type}, user_id={data.get('userId', 'unknown')}"
|
||||
)
|
||||
if notification_type in ["mention", "reply", "quote"]:
|
||||
note = data.get("note")
|
||||
if note and self._is_bot_mentioned(note):
|
||||
logger.info(
|
||||
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
|
||||
)
|
||||
message = await self.convert_message(note)
|
||||
event = MisskeyPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self,
|
||||
)
|
||||
self.commit_event(event)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理通知失败: {e}")
|
||||
|
||||
async def _handle_chat_message(self, data: Dict[str, Any]):
|
||||
try:
|
||||
sender_id = str(
|
||||
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
|
||||
)
|
||||
room_id = data.get("toRoomId")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到聊天事件: sender_id={sender_id}, room_id={room_id}, is_self={sender_id == self.client_self_id}"
|
||||
)
|
||||
if sender_id == self.client_self_id:
|
||||
return
|
||||
|
||||
if room_id:
|
||||
raw_text = data.get("text", "")
|
||||
logger.debug(
|
||||
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
|
||||
)
|
||||
|
||||
message = await self.convert_room_message(data)
|
||||
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
|
||||
else:
|
||||
message = await self.convert_chat_message(data)
|
||||
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
|
||||
|
||||
event = MisskeyPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self,
|
||||
)
|
||||
self.commit_event(event)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
||||
|
||||
async def _debug_handler(self, data: Dict[str, Any]):
|
||||
event_type = data.get("type", "unknown")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}"
|
||||
)
|
||||
|
||||
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
|
||||
text = note.get("text", "")
|
||||
if not text:
|
||||
return False
|
||||
|
||||
mentions = note.get("mentions", [])
|
||||
if self._bot_username and f"@{self._bot_username}" in text:
|
||||
return True
|
||||
if self.client_self_id in [str(uid) for uid in mentions]:
|
||||
return True
|
||||
|
||||
reply = note.get("reply")
|
||||
if reply and isinstance(reply, dict):
|
||||
reply_user_id = str(reply.get("user", {}).get("id", ""))
|
||||
if reply_user_id == self.client_self_id:
|
||||
return bool(self._bot_username and f"@{self._bot_username}" in text)
|
||||
|
||||
return False
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSession, message_chain: MessageChain
|
||||
) -> Awaitable[Any]:
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
try:
|
||||
session_id = session.session_id
|
||||
|
||||
text, has_at_user = serialize_message_chain(message_chain.chain)
|
||||
|
||||
if not has_at_user and session_id:
|
||||
# 从session_id中提取用户ID用于缓存查询
|
||||
# session_id格式为: "chat%<user_id>" 或 "room%<room_id>" 或 "note%<user_id>"
|
||||
user_id_for_cache = None
|
||||
if "%" in session_id:
|
||||
parts = session_id.split("%")
|
||||
if len(parts) >= 2:
|
||||
user_id_for_cache = parts[1]
|
||||
|
||||
user_info = None
|
||||
if user_id_for_cache:
|
||||
user_info = self._user_cache.get(user_id_for_cache)
|
||||
|
||||
text = add_at_mention_if_needed(text, user_info, has_at_user)
|
||||
|
||||
# 检查是否有文件组件
|
||||
has_file_components = any(
|
||||
isinstance(comp, Comp.Image)
|
||||
or isinstance(comp, Comp.File)
|
||||
or hasattr(comp, "convert_to_file_path")
|
||||
or hasattr(comp, "get_file")
|
||||
or any(
|
||||
hasattr(comp, a) for a in ("file", "url", "path", "src", "source")
|
||||
)
|
||||
for comp in message_chain.chain
|
||||
)
|
||||
|
||||
if not text or not text.strip():
|
||||
if not has_file_components:
|
||||
logger.warning("[Misskey] 消息内容为空且无文件组件,跳过发送")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
if len(text) > self.max_message_length:
|
||||
text = text[: self.max_message_length] + "..."
|
||||
|
||||
file_ids: List[str] = []
|
||||
fallback_urls: List[str] = []
|
||||
|
||||
if not self.enable_file_upload:
|
||||
return await self._send_text_only_message(
|
||||
session_id, text, session, message_chain
|
||||
)
|
||||
|
||||
MAX_UPLOAD_CONCURRENCY = 10
|
||||
upload_concurrency = int(
|
||||
self.config.get(
|
||||
"misskey_upload_concurrency", DEFAULT_UPLOAD_CONCURRENCY
|
||||
)
|
||||
)
|
||||
upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY)
|
||||
sem = asyncio.Semaphore(upload_concurrency)
|
||||
|
||||
async def _upload_comp(comp) -> Optional[object]:
|
||||
"""组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)"""
|
||||
from .misskey_utils import (
|
||||
resolve_component_url_or_path,
|
||||
upload_local_with_retries,
|
||||
)
|
||||
|
||||
local_path = None
|
||||
try:
|
||||
async with sem:
|
||||
if not self.api:
|
||||
return None
|
||||
|
||||
# 解析组件的 URL 或本地路径
|
||||
url_candidate, local_path = await resolve_component_url_or_path(
|
||||
comp
|
||||
)
|
||||
|
||||
if not url_candidate and not local_path:
|
||||
return None
|
||||
|
||||
preferred_name = getattr(comp, "name", None) or getattr(
|
||||
comp, "file", None
|
||||
)
|
||||
|
||||
# URL 上传:下载后本地上传
|
||||
if url_candidate:
|
||||
result = await self.api.upload_and_find_file(
|
||||
str(url_candidate),
|
||||
preferred_name,
|
||||
folder_id=self.upload_folder,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("id"):
|
||||
return str(result["id"])
|
||||
|
||||
# 本地文件上传
|
||||
if local_path:
|
||||
file_id = await upload_local_with_retries(
|
||||
self.api,
|
||||
str(local_path),
|
||||
preferred_name,
|
||||
self.upload_folder,
|
||||
)
|
||||
if file_id:
|
||||
return file_id
|
||||
|
||||
# 所有上传都失败,尝试获取 URL 作为回退
|
||||
if hasattr(comp, "register_to_file_service"):
|
||||
try:
|
||||
url = await comp.register_to_file_service()
|
||||
if url:
|
||||
return {"fallback_url": url}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if local_path and isinstance(local_path, str):
|
||||
data_temp = os.path.join(get_astrbot_data_path(), "temp")
|
||||
if local_path.startswith(data_temp) and os.path.exists(
|
||||
local_path
|
||||
):
|
||||
try:
|
||||
os.remove(local_path)
|
||||
logger.debug(f"[Misskey] 已清理临时文件: {local_path}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 收集所有可能包含文件/URL信息的组件:支持异步接口或同步字段
|
||||
file_components = []
|
||||
for comp in message_chain.chain:
|
||||
try:
|
||||
if (
|
||||
isinstance(comp, Comp.Image)
|
||||
or isinstance(comp, Comp.File)
|
||||
or hasattr(comp, "convert_to_file_path")
|
||||
or hasattr(comp, "get_file")
|
||||
or any(
|
||||
hasattr(comp, a)
|
||||
for a in ("file", "url", "path", "src", "source")
|
||||
)
|
||||
):
|
||||
file_components.append(comp)
|
||||
except Exception:
|
||||
# 保守跳过无法访问属性的组件
|
||||
continue
|
||||
|
||||
if len(file_components) > MAX_FILE_UPLOAD_COUNT:
|
||||
logger.warning(
|
||||
f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件"
|
||||
)
|
||||
file_components = file_components[:MAX_FILE_UPLOAD_COUNT]
|
||||
|
||||
upload_tasks = [_upload_comp(comp) for comp in file_components]
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*upload_tasks) if upload_tasks else []
|
||||
for r in results:
|
||||
if not r:
|
||||
continue
|
||||
if isinstance(r, dict) and r.get("fallback_url"):
|
||||
url = r.get("fallback_url")
|
||||
if url:
|
||||
fallback_urls.append(str(url))
|
||||
else:
|
||||
try:
|
||||
fid_str = str(r)
|
||||
if fid_str:
|
||||
file_ids.append(fid_str)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
logger.debug("[Misskey] 并发上传过程中出现异常,继续发送文本")
|
||||
|
||||
if session_id and is_valid_room_session_id(session_id):
|
||||
from .misskey_utils import extract_room_id_from_session_id
|
||||
|
||||
room_id = extract_room_id_from_session_id(session_id)
|
||||
if fallback_urls:
|
||||
appended = "\n" + "\n".join(fallback_urls)
|
||||
text = (text or "") + appended
|
||||
payload: Dict[str, Any] = {"toRoomId": room_id, "text": text}
|
||||
if file_ids:
|
||||
payload["fileIds"] = file_ids
|
||||
await self.api.send_room_message(payload)
|
||||
elif session_id:
|
||||
from .misskey_utils import (
|
||||
extract_user_id_from_session_id,
|
||||
is_valid_chat_session_id,
|
||||
)
|
||||
|
||||
if is_valid_chat_session_id(session_id):
|
||||
user_id = extract_user_id_from_session_id(session_id)
|
||||
if fallback_urls:
|
||||
appended = "\n" + "\n".join(fallback_urls)
|
||||
text = (text or "") + appended
|
||||
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
|
||||
if file_ids:
|
||||
# 聊天消息只支持单个文件,使用 fileId 而不是 fileIds
|
||||
payload["fileId"] = file_ids[0]
|
||||
if len(file_ids) > 1:
|
||||
logger.warning(
|
||||
f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件"
|
||||
)
|
||||
await self.api.send_message(payload)
|
||||
else:
|
||||
# 回退到发帖逻辑
|
||||
# 去掉 session_id 中的 note% 前缀以匹配 user_cache 的键格式
|
||||
user_id_for_cache = (
|
||||
session_id.split("%")[1] if "%" in session_id else session_id
|
||||
)
|
||||
|
||||
# 获取用户缓存信息(包含reply_to_note_id)
|
||||
user_info_for_reply = self._user_cache.get(user_id_for_cache, {})
|
||||
|
||||
visibility, visible_user_ids = resolve_message_visibility(
|
||||
user_id=user_id_for_cache,
|
||||
user_cache=self._user_cache,
|
||||
self_id=self.client_self_id,
|
||||
default_visibility=self.default_visibility,
|
||||
)
|
||||
logger.debug(
|
||||
f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}"
|
||||
)
|
||||
|
||||
fields = self._extract_additional_fields(session, message_chain)
|
||||
if fallback_urls:
|
||||
appended = "\n" + "\n".join(fallback_urls)
|
||||
text = (text or "") + appended
|
||||
|
||||
# 从缓存中获取原消息ID作为reply_id
|
||||
reply_id = user_info_for_reply.get("reply_to_note_id")
|
||||
|
||||
await self.api.create_note(
|
||||
text=text,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
file_ids=file_ids or None,
|
||||
local_only=self.local_only,
|
||||
reply_id=reply_id, # 添加reply_id参数
|
||||
cw=fields["cw"],
|
||||
poll=fields["poll"],
|
||||
renote_id=fields["renote_id"],
|
||||
channel_id=fields["channel_id"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 发送消息失败: {e}")
|
||||
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
|
||||
sender_info = extract_sender_info(raw_data, is_chat=False)
|
||||
message = create_base_message(
|
||||
raw_data,
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||
)
|
||||
|
||||
message_parts = []
|
||||
raw_text = raw_data.get("text", "")
|
||||
|
||||
if raw_text:
|
||||
text_parts, processed_text = process_at_mention(
|
||||
message, raw_text, self._bot_username, self.client_self_id
|
||||
)
|
||||
message_parts.extend(text_parts)
|
||||
|
||||
files = raw_data.get("files", [])
|
||||
file_parts = process_files(message, files)
|
||||
message_parts.extend(file_parts)
|
||||
|
||||
poll = raw_data.get("poll") or (
|
||||
raw_data.get("note", {}).get("poll")
|
||||
if isinstance(raw_data.get("note"), dict)
|
||||
else None
|
||||
)
|
||||
if poll and isinstance(poll, dict):
|
||||
self._process_poll_data(message, poll, message_parts)
|
||||
|
||||
message.message_str = (
|
||||
" ".join(part for part in message_parts if part.strip())
|
||||
if message_parts
|
||||
else ""
|
||||
)
|
||||
return message
|
||||
|
||||
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
|
||||
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||
message = create_base_message(
|
||||
raw_data,
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=True,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
|
||||
)
|
||||
|
||||
raw_text = raw_data.get("text", "")
|
||||
if raw_text:
|
||||
message.message.append(Comp.Plain(raw_text))
|
||||
|
||||
files = raw_data.get("files", [])
|
||||
process_files(message, files, include_text_parts=False)
|
||||
|
||||
message.message_str = raw_text if raw_text else ""
|
||||
return message
|
||||
|
||||
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
|
||||
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||
room_id = raw_data.get("toRoomId", "")
|
||||
message = create_base_message(
|
||||
raw_data,
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
room_id=room_id,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
|
||||
cache_user_info(
|
||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||
)
|
||||
cache_room_info(self._user_cache, raw_data, self.client_self_id)
|
||||
|
||||
raw_text = raw_data.get("text", "")
|
||||
message_parts = []
|
||||
|
||||
if raw_text:
|
||||
if self._bot_username and f"@{self._bot_username}" in raw_text:
|
||||
text_parts, processed_text = process_at_mention(
|
||||
message, raw_text, self._bot_username, self.client_self_id
|
||||
)
|
||||
message_parts.extend(text_parts)
|
||||
else:
|
||||
message.message.append(Comp.Plain(raw_text))
|
||||
message_parts.append(raw_text)
|
||||
|
||||
files = raw_data.get("files", [])
|
||||
file_parts = process_files(message, files)
|
||||
message_parts.extend(file_parts)
|
||||
|
||||
message.message_str = (
|
||||
" ".join(part for part in message_parts if part.strip())
|
||||
if message_parts
|
||||
else ""
|
||||
)
|
||||
return message
|
||||
|
||||
async def terminate(self):
|
||||
self._running = False
|
||||
if self.api:
|
||||
await self.api.close()
|
||||
|
||||
def get_client(self) -> Any:
|
||||
return self.api
|
||||
940
astrbot/core/platform/sources/misskey/misskey_api.py
Normal file
940
astrbot/core/platform/sources/misskey/misskey_api.py
Normal file
@@ -0,0 +1,940 @@
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
from typing import Any, Optional, Dict, List, Callable, Awaitable
|
||||
import uuid
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
import websockets
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
|
||||
) from e
|
||||
|
||||
from astrbot.api import logger
|
||||
from .misskey_utils import FileIDExtractor
|
||||
|
||||
# Constants
|
||||
API_MAX_RETRIES = 3
|
||||
HTTP_OK = 200
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""Misskey API 基础异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
"""网络连接异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIRateLimitError(APIError):
|
||||
"""API 频率限制异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(APIError):
|
||||
"""认证失败异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketError(APIError):
|
||||
"""WebSocket 连接异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StreamingClient:
|
||||
def __init__(self, instance_url: str, access_token: str):
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self.websocket: Optional[Any] = None
|
||||
self.is_connected = False
|
||||
self.message_handlers: Dict[str, Callable] = {}
|
||||
self.channels: Dict[str, str] = {}
|
||||
self.desired_channels: Dict[str, Optional[Dict]] = {}
|
||||
self._running = False
|
||||
self._last_pong = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
try:
|
||||
ws_url = self.instance_url.replace("https://", "wss://").replace(
|
||||
"http://", "ws://"
|
||||
)
|
||||
ws_url += f"/streaming?i={self.access_token}"
|
||||
|
||||
self.websocket = await websockets.connect(
|
||||
ws_url, ping_interval=30, ping_timeout=10
|
||||
)
|
||||
self.is_connected = True
|
||||
self._running = True
|
||||
|
||||
logger.info("[Misskey WebSocket] 已连接")
|
||||
if self.desired_channels:
|
||||
try:
|
||||
desired = list(self.desired_channels.items())
|
||||
for channel_type, params in desired:
|
||||
try:
|
||||
await self.subscribe_channel(channel_type, params)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
|
||||
self.is_connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
self._running = False
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
self.websocket = None
|
||||
self.is_connected = False
|
||||
logger.info("[Misskey WebSocket] 连接已断开")
|
||||
|
||||
async def subscribe_channel(
|
||||
self, channel_type: str, params: Optional[Dict] = None
|
||||
) -> str:
|
||||
if not self.is_connected or not self.websocket:
|
||||
raise WebSocketError("WebSocket 未连接")
|
||||
|
||||
channel_id = str(uuid.uuid4())
|
||||
message = {
|
||||
"type": "connect",
|
||||
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
|
||||
}
|
||||
|
||||
await self.websocket.send(json.dumps(message))
|
||||
self.channels[channel_id] = channel_type
|
||||
return channel_id
|
||||
|
||||
async def unsubscribe_channel(self, channel_id: str):
|
||||
if (
|
||||
not self.is_connected
|
||||
or not self.websocket
|
||||
or channel_id not in self.channels
|
||||
):
|
||||
return
|
||||
|
||||
message = {"type": "disconnect", "body": {"id": channel_id}}
|
||||
await self.websocket.send(json.dumps(message))
|
||||
channel_type = self.channels.get(channel_id)
|
||||
if channel_id in self.channels:
|
||||
del self.channels[channel_id]
|
||||
if channel_type and channel_type not in self.channels.values():
|
||||
self.desired_channels.pop(channel_type, None)
|
||||
|
||||
def add_message_handler(
|
||||
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
|
||||
):
|
||||
self.message_handlers[event_type] = handler
|
||||
|
||||
async def listen(self):
|
||||
if not self.is_connected or not self.websocket:
|
||||
raise WebSocketError("WebSocket 未连接")
|
||||
|
||||
try:
|
||||
async for message in self.websocket:
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._handle_message(data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(
|
||||
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
|
||||
)
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except websockets.exceptions.InvalidHandshake as e:
|
||||
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle_message(self, data: Dict[str, Any]):
|
||||
message_type = data.get("type")
|
||||
body = data.get("body", {})
|
||||
|
||||
def _build_channel_summary(message_type: Optional[str], body: Any) -> str:
|
||||
try:
|
||||
if not isinstance(body, dict):
|
||||
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
|
||||
|
||||
inner = body.get("body") if isinstance(body.get("body"), dict) else body
|
||||
note = (
|
||||
inner.get("note")
|
||||
if isinstance(inner, dict) and isinstance(inner.get("note"), dict)
|
||||
else None
|
||||
)
|
||||
|
||||
text = note.get("text") if note else None
|
||||
note_id = note.get("id") if note else None
|
||||
files = note.get("files") or [] if note else []
|
||||
has_files = bool(files)
|
||||
is_hidden = bool(note.get("isHidden")) if note else False
|
||||
user = note.get("user", {}) if note else None
|
||||
|
||||
return (
|
||||
f"[Misskey WebSocket] 收到消息类型: {message_type} | "
|
||||
f"note_id={note_id} | user={user.get('username') if user else None} | "
|
||||
f"text={text[:80] if text else '[no-text]'} | files={has_files} | hidden={is_hidden}"
|
||||
)
|
||||
except Exception:
|
||||
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
|
||||
|
||||
channel_summary = _build_channel_summary(message_type, body)
|
||||
logger.info(channel_summary)
|
||||
|
||||
if message_type == "channel":
|
||||
channel_id = body.get("id")
|
||||
event_type = body.get("type")
|
||||
event_body = body.get("body", {})
|
||||
|
||||
logger.debug(
|
||||
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
|
||||
)
|
||||
|
||||
if channel_id in self.channels:
|
||||
channel_type = self.channels[channel_id]
|
||||
handler_key = f"{channel_type}:{event_type}"
|
||||
|
||||
if handler_key in self.message_handlers:
|
||||
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
|
||||
await self.message_handlers[handler_key](event_body)
|
||||
elif event_type in self.message_handlers:
|
||||
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
|
||||
await self.message_handlers[event_type](event_body)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}"
|
||||
)
|
||||
if "_debug" in self.message_handlers:
|
||||
await self.message_handlers["_debug"](
|
||||
{
|
||||
"type": event_type,
|
||||
"body": event_body,
|
||||
"channel": channel_type,
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type in self.message_handlers:
|
||||
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
|
||||
await self.message_handlers[message_type](body)
|
||||
else:
|
||||
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
|
||||
if "_debug" in self.message_handlers:
|
||||
await self.message_handlers["_debug"](data)
|
||||
|
||||
|
||||
def retry_async(
|
||||
max_retries: int = 3,
|
||||
retryable_exceptions: tuple = (APIConnectionError, APIRateLimitError),
|
||||
backoff_base: float = 1.0,
|
||||
max_backoff: float = 30.0,
|
||||
):
|
||||
"""
|
||||
智能异步重试装饰器
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
retryable_exceptions: 可重试的异常类型
|
||||
backoff_base: 退避基数
|
||||
max_backoff: 最大退避时间
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
last_exc = None
|
||||
func_name = getattr(func, "__name__", "unknown")
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except retryable_exceptions as e:
|
||||
last_exc = e
|
||||
if attempt == max_retries:
|
||||
logger.error(
|
||||
f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}"
|
||||
)
|
||||
break
|
||||
|
||||
# 智能退避策略
|
||||
if isinstance(e, APIRateLimitError):
|
||||
# 频率限制用更长的退避时间
|
||||
backoff = min(backoff_base * (3**attempt), max_backoff)
|
||||
else:
|
||||
# 其他错误用指数退避
|
||||
backoff = min(backoff_base * (2**attempt), max_backoff)
|
||||
|
||||
jitter = random.uniform(0.1, 0.5) # 随机抖动
|
||||
sleep_time = backoff + jitter
|
||||
|
||||
logger.warning(
|
||||
f"[Misskey API] {func_name} 第 {attempt} 次重试失败: {e},"
|
||||
f"{sleep_time:.1f}s后重试"
|
||||
)
|
||||
await asyncio.sleep(sleep_time)
|
||||
continue
|
||||
except Exception as e:
|
||||
# 非可重试异常直接抛出
|
||||
logger.error(f"[Misskey API] {func_name} 遇到不可重试异常: {e}")
|
||||
raise
|
||||
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class MisskeyAPI:
|
||||
def __init__(
|
||||
self,
|
||||
instance_url: str,
|
||||
access_token: str,
|
||||
*,
|
||||
allow_insecure_downloads: bool = False,
|
||||
download_timeout: int = 15,
|
||||
chunk_size: int = 64 * 1024,
|
||||
max_download_bytes: Optional[int] = None,
|
||||
):
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self.streaming: Optional[StreamingClient] = None
|
||||
# download options
|
||||
self.allow_insecure_downloads = allow_insecure_downloads
|
||||
self.download_timeout = download_timeout
|
||||
self.chunk_size = chunk_size
|
||||
self.max_download_bytes = (
|
||||
int(max_download_bytes) if max_download_bytes is not None else None
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.streaming:
|
||||
await self.streaming.disconnect()
|
||||
self.streaming = None
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
logger.debug("[Misskey API] 客户端已关闭")
|
||||
|
||||
def get_streaming_client(self) -> StreamingClient:
|
||||
if not self.streaming:
|
||||
self.streaming = StreamingClient(self.instance_url, self.access_token)
|
||||
return self.streaming
|
||||
|
||||
@property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
headers = {"Authorization": f"Bearer {self.access_token}"}
|
||||
self._session = aiohttp.ClientSession(headers=headers)
|
||||
return self._session
|
||||
|
||||
def _handle_response_status(self, status: int, endpoint: str):
|
||||
"""处理 HTTP 响应状态码"""
|
||||
if status == 400:
|
||||
logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})")
|
||||
raise APIError(f"Bad request for {endpoint}")
|
||||
elif status == 401:
|
||||
logger.error(f"[Misskey API] 未授权访问: {endpoint} (HTTP {status})")
|
||||
raise AuthenticationError(f"Unauthorized access for {endpoint}")
|
||||
elif status == 403:
|
||||
logger.error(f"[Misskey API] 访问被禁止: {endpoint} (HTTP {status})")
|
||||
raise AuthenticationError(f"Forbidden access for {endpoint}")
|
||||
elif status == 404:
|
||||
logger.error(f"[Misskey API] 资源不存在: {endpoint} (HTTP {status})")
|
||||
raise APIError(f"Resource not found for {endpoint}")
|
||||
elif status == 413:
|
||||
logger.error(f"[Misskey API] 请求体过大: {endpoint} (HTTP {status})")
|
||||
raise APIError(f"Request entity too large for {endpoint}")
|
||||
elif status == 429:
|
||||
logger.warning(f"[Misskey API] 请求频率限制: {endpoint} (HTTP {status})")
|
||||
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
|
||||
elif status == 500:
|
||||
logger.error(f"[Misskey API] 服务器内部错误: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Internal server error for {endpoint}")
|
||||
elif status == 502:
|
||||
logger.error(f"[Misskey API] 网关错误: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Bad gateway for {endpoint}")
|
||||
elif status == 503:
|
||||
logger.error(f"[Misskey API] 服务不可用: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Service unavailable for {endpoint}")
|
||||
elif status == 504:
|
||||
logger.error(f"[Misskey API] 网关超时: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Gateway timeout for {endpoint}")
|
||||
else:
|
||||
logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"HTTP {status} for {endpoint}")
|
||||
|
||||
async def _process_response(
|
||||
self, response: aiohttp.ClientResponse, endpoint: str
|
||||
) -> Any:
|
||||
"""处理 API 响应"""
|
||||
if response.status == HTTP_OK:
|
||||
try:
|
||||
result = await response.json()
|
||||
if endpoint == "i/notifications":
|
||||
notifications_data = (
|
||||
result
|
||||
if isinstance(result, list)
|
||||
else result.get("notifications", [])
|
||||
if isinstance(result, dict)
|
||||
else []
|
||||
)
|
||||
if notifications_data:
|
||||
logger.debug(
|
||||
f"[Misskey API] 获取到 {len(notifications_data)} 条新通知"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[Misskey API] 请求成功: {endpoint}")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[Misskey API] 响应格式错误: {e}")
|
||||
raise APIConnectionError("Invalid JSON response") from e
|
||||
else:
|
||||
try:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}"
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}"
|
||||
)
|
||||
|
||||
self._handle_response_status(response.status, endpoint)
|
||||
raise APIConnectionError(f"Request failed for {endpoint}")
|
||||
|
||||
@retry_async(
|
||||
max_retries=API_MAX_RETRIES,
|
||||
retryable_exceptions=(APIConnectionError, APIRateLimitError),
|
||||
)
|
||||
async def _make_request(
|
||||
self, endpoint: str, data: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
url = f"{self.instance_url}/api/{endpoint}"
|
||||
payload = {"i": self.access_token}
|
||||
if data:
|
||||
payload.update(data)
|
||||
|
||||
try:
|
||||
async with self.session.post(url, json=payload) as response:
|
||||
return await self._process_response(response, endpoint)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[Misskey API] HTTP 请求错误: {e}")
|
||||
raise APIConnectionError(f"HTTP request failed: {e}") from e
|
||||
|
||||
async def create_note(
|
||||
self,
|
||||
text: Optional[str] = None,
|
||||
visibility: str = "public",
|
||||
reply_id: Optional[str] = None,
|
||||
visible_user_ids: Optional[List[str]] = None,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
local_only: bool = False,
|
||||
cw: Optional[str] = None,
|
||||
poll: Optional[Dict[str, Any]] = None,
|
||||
renote_id: Optional[str] = None,
|
||||
channel_id: Optional[str] = None,
|
||||
reaction_acceptance: Optional[str] = None,
|
||||
no_extract_mentions: Optional[bool] = None,
|
||||
no_extract_hashtags: Optional[bool] = None,
|
||||
no_extract_emojis: Optional[bool] = None,
|
||||
media_ids: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API."""
|
||||
data: Dict[str, Any] = {}
|
||||
|
||||
if text is not None:
|
||||
data["text"] = text
|
||||
|
||||
data["visibility"] = visibility
|
||||
data["localOnly"] = local_only
|
||||
|
||||
if reply_id:
|
||||
data["replyId"] = reply_id
|
||||
|
||||
if visible_user_ids and visibility == "specified":
|
||||
data["visibleUserIds"] = visible_user_ids
|
||||
|
||||
if file_ids:
|
||||
data["fileIds"] = file_ids
|
||||
if media_ids:
|
||||
data["mediaIds"] = media_ids
|
||||
|
||||
if cw is not None:
|
||||
data["cw"] = cw
|
||||
if poll is not None:
|
||||
data["poll"] = poll
|
||||
if renote_id is not None:
|
||||
data["renoteId"] = renote_id
|
||||
if channel_id is not None:
|
||||
data["channelId"] = channel_id
|
||||
if reaction_acceptance is not None:
|
||||
data["reactionAcceptance"] = reaction_acceptance
|
||||
if no_extract_mentions is not None:
|
||||
data["noExtractMentions"] = bool(no_extract_mentions)
|
||||
if no_extract_hashtags is not None:
|
||||
data["noExtractHashtags"] = bool(no_extract_hashtags)
|
||||
if no_extract_emojis is not None:
|
||||
data["noExtractEmojis"] = bool(no_extract_emojis)
|
||||
|
||||
result = await self._make_request("notes/create", data)
|
||||
note_id = (
|
||||
result.get("createdNote", {}).get("id", "unknown")
|
||||
if isinstance(result, dict)
|
||||
else "unknown"
|
||||
)
|
||||
logger.debug(f"[Misskey API] 发帖成功: {note_id}")
|
||||
return result
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
name: Optional[str] = None,
|
||||
folder_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Upload a file to Misskey drive/files/create and return a dict containing id and raw result."""
|
||||
if not file_path:
|
||||
raise APIError("No file path provided for upload")
|
||||
|
||||
url = f"{self.instance_url}/api/drive/files/create"
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("i", self.access_token)
|
||||
|
||||
try:
|
||||
filename = name or file_path.split("/")[-1]
|
||||
if folder_id:
|
||||
form.add_field("folderId", str(folder_id))
|
||||
|
||||
try:
|
||||
f = open(file_path, "rb")
|
||||
except FileNotFoundError as e:
|
||||
logger.error(f"[Misskey API] 本地文件不存在: {file_path}")
|
||||
raise APIError(f"File not found: {file_path}") from e
|
||||
|
||||
try:
|
||||
form.add_field("file", f, filename=filename)
|
||||
async with self.session.post(url, data=form) as resp:
|
||||
result = await self._process_response(resp, "drive/files/create")
|
||||
file_id = FileIDExtractor.extract_file_id(result)
|
||||
logger.debug(
|
||||
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}"
|
||||
)
|
||||
return {"id": file_id, "raw": result}
|
||||
finally:
|
||||
f.close()
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[Misskey API] 文件上传网络错误: {e}")
|
||||
raise APIConnectionError(f"Upload failed: {e}") from e
|
||||
|
||||
async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]:
|
||||
"""Find files by MD5 hash"""
|
||||
if not md5_hash:
|
||||
raise APIError("No MD5 hash provided for find-by-hash")
|
||||
|
||||
data = {"md5": md5_hash}
|
||||
|
||||
try:
|
||||
logger.debug(f"[Misskey API] find-by-hash 请求: md5={md5_hash}")
|
||||
result = await self._make_request("drive/files/find-by-hash", data)
|
||||
logger.debug(
|
||||
f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
|
||||
)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 根据哈希查找文件失败: {e}")
|
||||
raise
|
||||
|
||||
async def find_files_by_name(
|
||||
self, name: str, folder_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find files by name"""
|
||||
if not name:
|
||||
raise APIError("No name provided for find")
|
||||
|
||||
data: Dict[str, Any] = {"name": name}
|
||||
if folder_id:
|
||||
data["folderId"] = folder_id
|
||||
|
||||
try:
|
||||
logger.debug(f"[Misskey API] find 请求: name={name}, folder_id={folder_id}")
|
||||
result = await self._make_request("drive/files/find", data)
|
||||
logger.debug(
|
||||
f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
|
||||
)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 根据名称查找文件失败: {e}")
|
||||
raise
|
||||
|
||||
async def find_files(
|
||||
self,
|
||||
limit: int = 10,
|
||||
folder_id: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files with optional filters"""
|
||||
data: Dict[str, Any] = {"limit": limit}
|
||||
if folder_id is not None:
|
||||
data["folderId"] = folder_id
|
||||
if type is not None:
|
||||
data["type"] = type
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}"
|
||||
)
|
||||
result = await self._make_request("drive/files", data)
|
||||
logger.debug(
|
||||
f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
|
||||
)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 列表文件失败: {e}")
|
||||
raise
|
||||
|
||||
async def _download_with_existing_session(
|
||||
self, url: str, ssl_verify: bool = True
|
||||
) -> Optional[bytes]:
|
||||
"""使用现有会话下载文件"""
|
||||
if not (hasattr(self, "session") and self.session):
|
||||
raise APIConnectionError("No existing session available")
|
||||
|
||||
async with self.session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=15), ssl=ssl_verify
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.read()
|
||||
return None
|
||||
|
||||
async def _download_with_temp_session(
|
||||
self, url: str, ssl_verify: bool = True
|
||||
) -> Optional[bytes]:
|
||||
"""使用临时会话下载文件"""
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_verify)
|
||||
async with aiohttp.ClientSession(connector=connector) as temp_session:
|
||||
async with temp_session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.read()
|
||||
return None
|
||||
|
||||
async def upload_and_find_file(
|
||||
self,
|
||||
url: str,
|
||||
name: Optional[str] = None,
|
||||
folder_id: Optional[str] = None,
|
||||
max_wait_time: float = 30.0,
|
||||
check_interval: float = 2.0,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
简化的文件上传:尝试 URL 上传,失败则下载后本地上传
|
||||
|
||||
Args:
|
||||
url: 文件URL
|
||||
name: 文件名(可选)
|
||||
folder_id: 文件夹ID(可选)
|
||||
max_wait_time: 保留参数(未使用)
|
||||
check_interval: 保留参数(未使用)
|
||||
|
||||
Returns:
|
||||
包含文件ID和元信息的字典,失败时返回None
|
||||
"""
|
||||
if not url:
|
||||
raise APIError("URL不能为空")
|
||||
|
||||
# 通过本地上传获取即时文件 ID(下载文件 → 上传 → 返回 ID)
|
||||
try:
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# SSL 验证下载,失败则重试不验证 SSL
|
||||
tmp_bytes = None
|
||||
try:
|
||||
tmp_bytes = await self._download_with_existing_session(
|
||||
url, ssl_verify=True
|
||||
) or await self._download_with_temp_session(url, ssl_verify=True)
|
||||
except Exception as ssl_error:
|
||||
logger.debug(
|
||||
f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL"
|
||||
)
|
||||
try:
|
||||
tmp_bytes = await self._download_with_existing_session(
|
||||
url, ssl_verify=False
|
||||
) or await self._download_with_temp_session(url, ssl_verify=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if tmp_bytes:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmpf:
|
||||
tmpf.write(tmp_bytes)
|
||||
tmp_path = tmpf.name
|
||||
|
||||
try:
|
||||
result = await self.upload_file(tmp_path, name, folder_id)
|
||||
logger.debug(f"[Misskey API] 本地上传成功: {result.get('id')}")
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 本地上传失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def get_current_user(self) -> Dict[str, Any]:
|
||||
"""获取当前用户信息"""
|
||||
return await self._make_request("i", {})
|
||||
|
||||
async def send_message(
|
||||
self, user_id_or_payload: Any, text: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送聊天消息。
|
||||
|
||||
Accepts either (user_id: str, text: str) or a single dict payload prepared by caller.
|
||||
"""
|
||||
if isinstance(user_id_or_payload, dict):
|
||||
data = user_id_or_payload
|
||||
else:
|
||||
data = {"toUserId": user_id_or_payload, "text": text}
|
||||
|
||||
result = await self._make_request("chat/messages/create-to-user", data)
|
||||
message_id = result.get("id", "unknown")
|
||||
logger.debug(f"[Misskey API] 聊天消息发送成功: {message_id}")
|
||||
return result
|
||||
|
||||
async def send_room_message(
|
||||
self, room_id_or_payload: Any, text: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送房间消息。
|
||||
|
||||
Accepts either (room_id: str, text: str) or a single dict payload.
|
||||
"""
|
||||
if isinstance(room_id_or_payload, dict):
|
||||
data = room_id_or_payload
|
||||
else:
|
||||
data = {"toRoomId": room_id_or_payload, "text": text}
|
||||
|
||||
result = await self._make_request("chat/messages/create-to-room", data)
|
||||
message_id = result.get("id", "unknown")
|
||||
logger.debug(f"[Misskey API] 房间消息发送成功: {message_id}")
|
||||
return result
|
||||
|
||||
async def get_messages(
|
||||
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取聊天消息历史"""
|
||||
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
|
||||
if since_id:
|
||||
data["sinceId"] = since_id
|
||||
|
||||
result = await self._make_request("chat/messages/user-timeline", data)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
logger.warning(f"[Misskey API] 聊天消息响应格式异常: {type(result)}")
|
||||
return []
|
||||
|
||||
async def get_mentions(
|
||||
self, limit: int = 10, since_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取提及通知"""
|
||||
data: Dict[str, Any] = {"limit": limit}
|
||||
if since_id:
|
||||
data["sinceId"] = since_id
|
||||
data["includeTypes"] = ["mention", "reply", "quote"]
|
||||
|
||||
result = await self._make_request("i/notifications", data)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
elif isinstance(result, dict) and "notifications" in result:
|
||||
return result["notifications"]
|
||||
else:
|
||||
logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}")
|
||||
return []
|
||||
|
||||
async def send_message_with_media(
|
||||
self,
|
||||
message_type: str,
|
||||
target_id: str,
|
||||
text: Optional[str] = None,
|
||||
media_urls: Optional[List[str]] = None,
|
||||
local_files: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通用消息发送函数:统一处理文本+媒体发送
|
||||
|
||||
Args:
|
||||
message_type: 消息类型 ('chat', 'room', 'note')
|
||||
target_id: 目标ID (用户ID/房间ID/频道ID等)
|
||||
text: 文本内容
|
||||
media_urls: 媒体文件URL列表
|
||||
local_files: 本地文件路径列表
|
||||
**kwargs: 其他参数(如visibility等)
|
||||
|
||||
Returns:
|
||||
发送结果字典
|
||||
|
||||
Raises:
|
||||
APIError: 参数错误或发送失败
|
||||
"""
|
||||
if not text and not media_urls and not local_files:
|
||||
raise APIError("消息内容不能为空:需要文本或媒体文件")
|
||||
|
||||
file_ids = []
|
||||
|
||||
# 处理远程媒体文件
|
||||
if media_urls:
|
||||
file_ids.extend(await self._process_media_urls(media_urls))
|
||||
|
||||
# 处理本地文件
|
||||
if local_files:
|
||||
file_ids.extend(await self._process_local_files(local_files))
|
||||
|
||||
# 根据消息类型发送
|
||||
return await self._dispatch_message(
|
||||
message_type, target_id, text, file_ids, **kwargs
|
||||
)
|
||||
|
||||
async def _process_media_urls(self, urls: List[str]) -> List[str]:
|
||||
"""处理远程媒体文件URL列表,返回文件ID列表"""
|
||||
file_ids = []
|
||||
for url in urls:
|
||||
try:
|
||||
result = await self.upload_and_find_file(url)
|
||||
if result and result.get("id"):
|
||||
file_ids.append(result["id"])
|
||||
logger.debug(f"[Misskey API] URL媒体上传成功: {result['id']}")
|
||||
else:
|
||||
logger.error(f"[Misskey API] URL媒体上传失败: {url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] URL媒体处理失败 {url}: {e}")
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
continue
|
||||
return file_ids
|
||||
|
||||
async def _process_local_files(self, file_paths: List[str]) -> List[str]:
|
||||
"""处理本地文件路径列表,返回文件ID列表"""
|
||||
file_ids = []
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
result = await self.upload_file(file_path)
|
||||
if result and result.get("id"):
|
||||
file_ids.append(result["id"])
|
||||
logger.debug(f"[Misskey API] 本地文件上传成功: {result['id']}")
|
||||
else:
|
||||
logger.error(f"[Misskey API] 本地文件上传失败: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 本地文件处理失败 {file_path}: {e}")
|
||||
continue
|
||||
return file_ids
|
||||
|
||||
async def _dispatch_message(
|
||||
self,
|
||||
message_type: str,
|
||||
target_id: str,
|
||||
text: Optional[str],
|
||||
file_ids: List[str],
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""根据消息类型分发到对应的发送方法"""
|
||||
if message_type == "chat":
|
||||
# 聊天消息使用 fileId (单数)
|
||||
payload = {"toUserId": target_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_ids:
|
||||
if len(file_ids) == 1:
|
||||
payload["fileId"] = file_ids[0]
|
||||
else:
|
||||
# 多文件时逐个发送
|
||||
results = []
|
||||
for file_id in file_ids:
|
||||
single_payload = payload.copy()
|
||||
single_payload["fileId"] = file_id
|
||||
result = await self.send_message(single_payload)
|
||||
results.append(result)
|
||||
return {"multiple": True, "results": results}
|
||||
return await self.send_message(payload)
|
||||
|
||||
elif message_type == "room":
|
||||
# 房间消息使用 fileId (单数)
|
||||
payload = {"toRoomId": target_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_ids:
|
||||
if len(file_ids) == 1:
|
||||
payload["fileId"] = file_ids[0]
|
||||
else:
|
||||
# 多文件时逐个发送
|
||||
results = []
|
||||
for file_id in file_ids:
|
||||
single_payload = payload.copy()
|
||||
single_payload["fileId"] = file_id
|
||||
result = await self.send_room_message(single_payload)
|
||||
results.append(result)
|
||||
return {"multiple": True, "results": results}
|
||||
return await self.send_room_message(payload)
|
||||
|
||||
elif message_type == "note":
|
||||
# 发帖使用 fileIds (复数)
|
||||
note_kwargs = {
|
||||
"text": text,
|
||||
"file_ids": file_ids or None,
|
||||
}
|
||||
# 合并其他参数
|
||||
note_kwargs.update(kwargs)
|
||||
return await self.create_note(**note_kwargs)
|
||||
|
||||
else:
|
||||
raise APIError(f"不支持的消息类型: {message_type}")
|
||||
158
astrbot/core/platform/sources/misskey/misskey_event.py
Normal file
158
astrbot/core/platform/sources/misskey/misskey_event.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
|
||||
from astrbot.api.message_components import Plain
|
||||
|
||||
from .misskey_utils import (
|
||||
serialize_message_chain,
|
||||
resolve_visibility_from_raw_message,
|
||||
is_valid_user_session_id,
|
||||
is_valid_room_session_id,
|
||||
add_at_mention_if_needed,
|
||||
extract_user_id_from_session_id,
|
||||
extract_room_id_from_session_id,
|
||||
)
|
||||
|
||||
|
||||
class MisskeyPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
def _is_system_command(self, message_str: str) -> bool:
|
||||
"""检测是否为系统指令"""
|
||||
if not message_str or not message_str.strip():
|
||||
return False
|
||||
|
||||
system_prefixes = ["/", "!", "#", ".", "^"]
|
||||
message_trimmed = message_str.strip()
|
||||
|
||||
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息,使用适配器的完整上传和发送逻辑"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件"
|
||||
)
|
||||
|
||||
# 使用适配器的 send_by_session 方法,它包含文件上传逻辑
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
# 根据session_id类型确定消息类型
|
||||
if is_valid_user_session_id(self.session_id):
|
||||
message_type = MessageType.FRIEND_MESSAGE
|
||||
elif is_valid_room_session_id(self.session_id):
|
||||
message_type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
message_type = MessageType.FRIEND_MESSAGE # 默认
|
||||
|
||||
session = MessageSession(
|
||||
platform_name=self.platform_meta.name,
|
||||
message_type=message_type,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}"
|
||||
)
|
||||
|
||||
# 调用适配器的 send_by_session 方法
|
||||
if hasattr(self.client, "send_by_session"):
|
||||
logger.debug("[MisskeyEvent] 调用适配器的 send_by_session 方法")
|
||||
await self.client.send_by_session(session, message)
|
||||
else:
|
||||
# 回退到原来的简化发送逻辑
|
||||
content, has_at = serialize_message_chain(message.chain)
|
||||
|
||||
if not content:
|
||||
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
||||
return
|
||||
|
||||
original_message_id = getattr(self.message_obj, "message_id", None)
|
||||
raw_message = getattr(self.message_obj, "raw_message", {})
|
||||
|
||||
if raw_message and not has_at:
|
||||
user_data = raw_message.get("user", {})
|
||||
user_info = {
|
||||
"username": user_data.get("username", ""),
|
||||
"nickname": user_data.get(
|
||||
"name", user_data.get("username", "")
|
||||
),
|
||||
}
|
||||
content = add_at_mention_if_needed(content, user_info, has_at)
|
||||
|
||||
# 根据会话类型选择发送方式
|
||||
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
||||
self.session_id
|
||||
):
|
||||
user_id = extract_user_id_from_session_id(self.session_id)
|
||||
await self.client.send_message(user_id, content)
|
||||
elif hasattr(
|
||||
self.client, "send_room_message"
|
||||
) and is_valid_room_session_id(self.session_id):
|
||||
room_id = extract_room_id_from_session_id(self.session_id)
|
||||
await self.client.send_room_message(room_id, content)
|
||||
elif original_message_id and hasattr(self.client, "create_note"):
|
||||
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
||||
raw_message
|
||||
)
|
||||
await self.client.create_note(
|
||||
content,
|
||||
reply_id=original_message_id,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
)
|
||||
elif hasattr(self.client, "create_note"):
|
||||
logger.debug("[MisskeyEvent] 创建新帖子")
|
||||
await self.client.create_note(content)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MisskeyEvent] 发送失败: {e}")
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
538
astrbot/core/platform/sources/misskey/misskey_utils.py
Normal file
538
astrbot/core/platform/sources/misskey/misskey_utils.py
Normal file
@@ -0,0 +1,538 @@
|
||||
"""Misskey 平台适配器通用工具函数"""
|
||||
|
||||
from typing import Dict, Any, List, Tuple, Optional, Union
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
|
||||
|
||||
class FileIDExtractor:
|
||||
"""从 API 响应中提取文件 ID 的帮助类(无状态)。"""
|
||||
|
||||
@staticmethod
|
||||
def extract_file_id(result: Any) -> Optional[str]:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
|
||||
id_paths = [
|
||||
lambda r: r.get("createdFile", {}).get("id"),
|
||||
lambda r: r.get("file", {}).get("id"),
|
||||
lambda r: r.get("id"),
|
||||
]
|
||||
|
||||
for p in id_paths:
|
||||
try:
|
||||
if fid := p(result):
|
||||
return fid
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MessagePayloadBuilder:
|
||||
"""构建不同类型消息负载的帮助类(无状态)。"""
|
||||
|
||||
@staticmethod
|
||||
def build_chat_payload(
|
||||
user_id: str, text: Optional[str], file_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
payload = {"toUserId": user_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_id:
|
||||
payload["fileId"] = file_id
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def build_room_payload(
|
||||
room_id: str, text: Optional[str], file_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
payload = {"toRoomId": room_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_id:
|
||||
payload["fileId"] = file_id
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def build_note_payload(
|
||||
text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_ids:
|
||||
payload["fileIds"] = file_ids
|
||||
payload |= kwargs
|
||||
return payload
|
||||
|
||||
|
||||
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||
"""将消息链序列化为文本字符串"""
|
||||
text_parts = []
|
||||
has_at = False
|
||||
|
||||
def process_component(component):
|
||||
nonlocal has_at
|
||||
if isinstance(component, Comp.Plain):
|
||||
return component.text
|
||||
elif isinstance(component, Comp.File):
|
||||
# 为文件组件返回占位符,但适配器仍会处理原组件
|
||||
return "[文件]"
|
||||
elif isinstance(component, Comp.Image):
|
||||
# 为图片组件返回占位符,但适配器仍会处理原组件
|
||||
return "[图片]"
|
||||
elif isinstance(component, Comp.At):
|
||||
has_at = True
|
||||
# 优先使用name字段(用户名),如果没有则使用qq字段
|
||||
# 这样可以避免在Misskey中生成 @<user_id> 这样的无效提及
|
||||
if hasattr(component, "name") and component.name:
|
||||
return f"@{component.name}"
|
||||
else:
|
||||
return f"@{component.qq}"
|
||||
elif hasattr(component, "text"):
|
||||
text = getattr(component, "text", "")
|
||||
if "@" in text:
|
||||
has_at = True
|
||||
return text
|
||||
else:
|
||||
return str(component)
|
||||
|
||||
for component in chain:
|
||||
if isinstance(component, Comp.Node) and component.content:
|
||||
for node_comp in component.content:
|
||||
result = process_component(node_comp)
|
||||
if result:
|
||||
text_parts.append(result)
|
||||
else:
|
||||
result = process_component(component)
|
||||
if result:
|
||||
text_parts.append(result)
|
||||
|
||||
return "".join(text_parts), has_at
|
||||
|
||||
|
||||
def resolve_message_visibility(
|
||||
user_id: Optional[str] = None,
|
||||
user_cache: Optional[Dict[str, Any]] = None,
|
||||
self_id: Optional[str] = None,
|
||||
raw_message: Optional[Dict[str, Any]] = None,
|
||||
default_visibility: str = "public",
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""解析 Misskey 消息的可见性设置
|
||||
|
||||
可以从 user_cache 或 raw_message 中解析,支持两种调用方式:
|
||||
1. 基于 user_cache: resolve_message_visibility(user_id, user_cache, self_id)
|
||||
2. 基于 raw_message: resolve_message_visibility(raw_message=raw_message, self_id=self_id)
|
||||
"""
|
||||
visibility = default_visibility
|
||||
visible_user_ids = None
|
||||
|
||||
# 优先从 user_cache 解析
|
||||
if user_id and user_cache:
|
||||
user_info = user_cache.get(user_id)
|
||||
if user_info:
|
||||
original_visibility = user_info.get("visibility", default_visibility)
|
||||
if original_visibility == "specified":
|
||||
visibility = "specified"
|
||||
original_visible_users = user_info.get("visible_user_ids", [])
|
||||
users_to_include = [user_id]
|
||||
if self_id:
|
||||
users_to_include.append(self_id)
|
||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
return visibility, visible_user_ids
|
||||
|
||||
# 回退到从 raw_message 解析
|
||||
if raw_message:
|
||||
original_visibility = raw_message.get("visibility", default_visibility)
|
||||
if original_visibility == "specified":
|
||||
visibility = "specified"
|
||||
original_visible_users = raw_message.get("visibleUserIds", [])
|
||||
sender_id = raw_message.get("userId", "")
|
||||
|
||||
users_to_include = []
|
||||
if sender_id:
|
||||
users_to_include.append(sender_id)
|
||||
if self_id:
|
||||
users_to_include.append(self_id)
|
||||
|
||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
|
||||
return visibility, visible_user_ids
|
||||
|
||||
|
||||
# 保留旧函数名作为向后兼容的别名
|
||||
def resolve_visibility_from_raw_message(
|
||||
raw_message: Dict[str, Any], self_id: Optional[str] = None
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)"""
|
||||
return resolve_message_visibility(raw_message=raw_message, self_id=self_id)
|
||||
|
||||
|
||||
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
|
||||
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
|
||||
if not isinstance(session_id, str) or "%" not in session_id:
|
||||
return False
|
||||
|
||||
parts = session_id.split("%")
|
||||
return (
|
||||
len(parts) == 2
|
||||
and parts[0] == "chat"
|
||||
and bool(parts[1])
|
||||
and parts[1] != "unknown"
|
||||
)
|
||||
|
||||
|
||||
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
|
||||
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
|
||||
if not isinstance(session_id, str) or "%" not in session_id:
|
||||
return False
|
||||
|
||||
parts = session_id.split("%")
|
||||
return (
|
||||
len(parts) == 2
|
||||
and parts[0] == "room"
|
||||
and bool(parts[1])
|
||||
and parts[1] != "unknown"
|
||||
)
|
||||
|
||||
|
||||
def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool:
|
||||
"""检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)"""
|
||||
if not isinstance(session_id, str) or "%" not in session_id:
|
||||
return False
|
||||
|
||||
parts = session_id.split("%")
|
||||
return (
|
||||
len(parts) == 2
|
||||
and parts[0] == "chat"
|
||||
and bool(parts[1])
|
||||
and parts[1] != "unknown"
|
||||
)
|
||||
|
||||
|
||||
def extract_user_id_from_session_id(session_id: str) -> str:
|
||||
"""从 session_id 中提取用户 ID"""
|
||||
if "%" in session_id:
|
||||
parts = session_id.split("%")
|
||||
if len(parts) >= 2:
|
||||
return parts[1]
|
||||
return session_id
|
||||
|
||||
|
||||
def extract_room_id_from_session_id(session_id: str) -> str:
|
||||
"""从 session_id 中提取房间 ID"""
|
||||
if "%" in session_id:
|
||||
parts = session_id.split("%")
|
||||
if len(parts) >= 2 and parts[0] == "room":
|
||||
return parts[1]
|
||||
return session_id
|
||||
|
||||
|
||||
def add_at_mention_if_needed(
|
||||
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
|
||||
) -> str:
|
||||
"""如果需要且没有@用户,则添加@用户
|
||||
|
||||
注意:仅在有有效的username时才添加@提及,避免使用用户ID
|
||||
"""
|
||||
if has_at or not user_info:
|
||||
return text
|
||||
|
||||
username = user_info.get("username")
|
||||
# 如果没有username,则不添加@提及,返回原文本
|
||||
# 这样可以避免生成 @<user_id> 这样的无效提及
|
||||
if not username:
|
||||
return text
|
||||
|
||||
mention = f"@{username}"
|
||||
if not text.startswith(mention):
|
||||
text = f"{mention}\n{text}".strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
|
||||
"""创建文件组件和描述文本"""
|
||||
file_url = file_info.get("url", "")
|
||||
file_name = file_info.get("name", "未知文件")
|
||||
file_type = file_info.get("type", "")
|
||||
|
||||
if file_type.startswith("image/"):
|
||||
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
|
||||
elif file_type.startswith("audio/"):
|
||||
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
|
||||
elif file_type.startswith("video/"):
|
||||
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
|
||||
else:
|
||||
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
|
||||
|
||||
|
||||
def process_files(
|
||||
message: AstrBotMessage, files: list, include_text_parts: bool = True
|
||||
) -> list:
|
||||
"""处理文件列表,添加到消息组件中并返回文本描述"""
|
||||
file_parts = []
|
||||
for file_info in files:
|
||||
component, part_text = create_file_component(file_info)
|
||||
message.message.append(component)
|
||||
if include_text_parts:
|
||||
file_parts.append(part_text)
|
||||
return file_parts
|
||||
|
||||
|
||||
def format_poll(poll: Dict[str, Any]) -> str:
|
||||
"""将 Misskey 的 poll 对象格式化为可读字符串。"""
|
||||
if not poll or not isinstance(poll, dict):
|
||||
return ""
|
||||
multiple = poll.get("multiple", False)
|
||||
choices = poll.get("choices", [])
|
||||
text_choices = [
|
||||
f"({idx}) {c.get('text', '')} [{c.get('votes', 0)}票]"
|
||||
for idx, c in enumerate(choices, start=1)
|
||||
]
|
||||
parts = ["[投票]", ("允许多选" if multiple else "单选")] + (
|
||||
["选项: " + ", ".join(text_choices)] if text_choices else []
|
||||
)
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def extract_sender_info(
|
||||
raw_data: Dict[str, Any], is_chat: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""提取发送者信息"""
|
||||
if is_chat:
|
||||
sender = raw_data.get("fromUser", {})
|
||||
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
|
||||
else:
|
||||
sender = raw_data.get("user", {})
|
||||
sender_id = str(sender.get("id", ""))
|
||||
|
||||
return {
|
||||
"sender": sender,
|
||||
"sender_id": sender_id,
|
||||
"nickname": sender.get("name", sender.get("username", "")),
|
||||
"username": sender.get("username", ""),
|
||||
}
|
||||
|
||||
|
||||
def create_base_message(
|
||||
raw_data: Dict[str, Any],
|
||||
sender_info: Dict[str, Any],
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
room_id: Optional[str] = None,
|
||||
unique_session: bool = False,
|
||||
) -> AstrBotMessage:
|
||||
"""创建基础消息对象"""
|
||||
message = AstrBotMessage()
|
||||
message.raw_message = raw_data
|
||||
message.message = []
|
||||
|
||||
message.sender = MessageMember(
|
||||
user_id=sender_info["sender_id"],
|
||||
nickname=sender_info["nickname"],
|
||||
)
|
||||
|
||||
if room_id:
|
||||
session_prefix = "room"
|
||||
session_id = f"{session_prefix}%{room_id}"
|
||||
if unique_session:
|
||||
session_id += f"_{sender_info['sender_id']}"
|
||||
message.type = MessageType.GROUP_MESSAGE
|
||||
message.group_id = room_id
|
||||
elif is_chat:
|
||||
session_prefix = "chat"
|
||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
else:
|
||||
session_prefix = "note"
|
||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||
message.type = MessageType.OTHER_MESSAGE
|
||||
|
||||
message.session_id = (
|
||||
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
|
||||
)
|
||||
message.message_id = str(raw_data.get("id", ""))
|
||||
message.self_id = client_self_id
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def process_at_mention(
|
||||
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
|
||||
) -> Tuple[List[str], str]:
|
||||
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
|
||||
message_parts = []
|
||||
|
||||
if not raw_text:
|
||||
return message_parts, ""
|
||||
|
||||
if bot_username and raw_text.startswith(f"@{bot_username}"):
|
||||
at_mention = f"@{bot_username}"
|
||||
message.message.append(Comp.At(qq=client_self_id))
|
||||
remaining_text = raw_text[len(at_mention) :].strip()
|
||||
if remaining_text:
|
||||
message.message.append(Comp.Plain(remaining_text))
|
||||
message_parts.append(remaining_text)
|
||||
return message_parts, remaining_text
|
||||
else:
|
||||
message.message.append(Comp.Plain(raw_text))
|
||||
message_parts.append(raw_text)
|
||||
return message_parts, raw_text
|
||||
|
||||
|
||||
def cache_user_info(
|
||||
user_cache: Dict[str, Any],
|
||||
sender_info: Dict[str, Any],
|
||||
raw_data: Dict[str, Any],
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
):
|
||||
"""缓存用户信息"""
|
||||
if is_chat:
|
||||
user_cache_data = {
|
||||
"username": sender_info["username"],
|
||||
"nickname": sender_info["nickname"],
|
||||
"visibility": "specified",
|
||||
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
|
||||
}
|
||||
else:
|
||||
user_cache_data = {
|
||||
"username": sender_info["username"],
|
||||
"nickname": sender_info["nickname"],
|
||||
"visibility": raw_data.get("visibility", "public"),
|
||||
"visible_user_ids": raw_data.get("visibleUserIds", []),
|
||||
# 保存原消息ID,用于回复时作为reply_id
|
||||
"reply_to_note_id": raw_data.get("id"),
|
||||
}
|
||||
|
||||
user_cache[sender_info["sender_id"]] = user_cache_data
|
||||
|
||||
|
||||
def cache_room_info(
|
||||
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
|
||||
):
|
||||
"""缓存房间信息"""
|
||||
room_data = raw_data.get("toRoom")
|
||||
room_id = raw_data.get("toRoomId")
|
||||
|
||||
if room_data and room_id:
|
||||
room_cache_key = f"room:{room_id}"
|
||||
user_cache[room_cache_key] = {
|
||||
"room_id": room_id,
|
||||
"room_name": room_data.get("name", ""),
|
||||
"room_description": room_data.get("description", ""),
|
||||
"owner_id": room_data.get("ownerId", ""),
|
||||
"visibility": "specified",
|
||||
"visible_user_ids": [client_self_id],
|
||||
}
|
||||
|
||||
|
||||
async def resolve_component_url_or_path(
|
||||
comp: Any,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""尝试从组件解析可上传的远程 URL 或本地路径。
|
||||
|
||||
返回 (url_candidate, local_path)。两者可能都为 None。
|
||||
这个函数尽量不抛异常,调用方可按需处理 None。
|
||||
"""
|
||||
url_candidate = None
|
||||
local_path = None
|
||||
|
||||
async def _get_str_value(coro_or_val):
|
||||
"""辅助函数:统一处理协程或普通值"""
|
||||
try:
|
||||
if hasattr(coro_or_val, "__await__"):
|
||||
result = await coro_or_val
|
||||
else:
|
||||
result = coro_or_val
|
||||
return result if isinstance(result, str) else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 尝试异步方法
|
||||
for method in ["convert_to_file_path", "get_file", "register_to_file_service"]:
|
||||
if not hasattr(comp, method):
|
||||
continue
|
||||
try:
|
||||
value = await _get_str_value(getattr(comp, method)())
|
||||
if value:
|
||||
if value.startswith("http"):
|
||||
url_candidate = value
|
||||
break
|
||||
else:
|
||||
local_path = value
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2. 尝试 get_file(True) 获取可直接访问的 URL
|
||||
if not url_candidate and hasattr(comp, "get_file"):
|
||||
try:
|
||||
value = await _get_str_value(comp.get_file(True))
|
||||
if value and value.startswith("http"):
|
||||
url_candidate = value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. 回退到同步属性
|
||||
if not url_candidate and not local_path:
|
||||
for attr in ("file", "url", "path", "src", "source"):
|
||||
try:
|
||||
value = getattr(comp, attr, None)
|
||||
if value and isinstance(value, str):
|
||||
if value.startswith("http"):
|
||||
url_candidate = value
|
||||
break
|
||||
else:
|
||||
local_path = value
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return url_candidate, local_path
|
||||
|
||||
|
||||
def summarize_component_for_log(comp: Any) -> Dict[str, Any]:
|
||||
"""生成适合日志的组件属性字典(尽量不抛异常)。"""
|
||||
attrs = {}
|
||||
for a in ("file", "url", "path", "src", "source", "name"):
|
||||
try:
|
||||
v = getattr(comp, a, None)
|
||||
if v is not None:
|
||||
attrs[a] = v
|
||||
except Exception:
|
||||
continue
|
||||
return attrs
|
||||
|
||||
|
||||
async def upload_local_with_retries(
|
||||
api: Any,
|
||||
local_path: str,
|
||||
preferred_name: Optional[str],
|
||||
folder_id: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。"""
|
||||
try:
|
||||
res = await api.upload_file(local_path, preferred_name, folder_id)
|
||||
if isinstance(res, dict):
|
||||
fid = res.get("id") or (res.get("raw") or {}).get("createdFile", {}).get(
|
||||
"id"
|
||||
)
|
||||
if fid:
|
||||
return str(fid)
|
||||
except Exception:
|
||||
# 上传失败,直接返回 None,让上层处理错误
|
||||
return None
|
||||
|
||||
return None
|
||||
@@ -15,12 +15,13 @@ class QQOfficialWebhook:
|
||||
self.appid = config["appid"]
|
||||
self.secret = config["secret"]
|
||||
self.port = config.get("port", 6196)
|
||||
self.is_sandbox = config.get("is_sandbox", False)
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
|
||||
if isinstance(self.port, str):
|
||||
self.port = int(self.port)
|
||||
|
||||
self.http: BotHttp = BotHttp(timeout=300)
|
||||
self.http: BotHttp = BotHttp(timeout=300, is_sandbox=self.is_sandbox)
|
||||
self.api: BotAPI = BotAPI(http=self.http)
|
||||
self.token = Token(self.appid, self.secret)
|
||||
|
||||
|
||||
@@ -17,7 +17,14 @@ from astrbot.api.platform import (
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
Reply,
|
||||
)
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
|
||||
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
|
||||
)
|
||||
self.token = self.config.get("satori_token", "")
|
||||
self.endpoint = self.config.get(
|
||||
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
|
||||
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
|
||||
)
|
||||
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
||||
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
||||
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="satori",
|
||||
description="Satori 通用协议适配器",
|
||||
id=self.config["id"],
|
||||
)
|
||||
|
||||
self.ws: Optional[ClientConnection] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.sequence = 0
|
||||
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
|
||||
return self.metadata
|
||||
|
||||
def _is_websocket_closed(self, ws) -> bool:
|
||||
"""检查WebSocket连接是否已关闭"""
|
||||
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
|
||||
|
||||
abm.self_id = login.get("user", {}).get("id", "")
|
||||
|
||||
content = message.get("content", "")
|
||||
abm.message = await self.parse_satori_elements(content)
|
||||
# 消息链
|
||||
abm.message = []
|
||||
|
||||
content = message.get("content", "")
|
||||
|
||||
quote = message.get("quote")
|
||||
content_for_parsing = content # 副本
|
||||
|
||||
# 提取<quote>标签
|
||||
if "<quote" in content:
|
||||
try:
|
||||
quote_info = await self._extract_quote_element(content)
|
||||
if quote_info:
|
||||
quote = quote_info["quote"]
|
||||
content_for_parsing = quote_info["content_without_quote"]
|
||||
except Exception as e:
|
||||
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
|
||||
|
||||
if quote:
|
||||
# 引用消息
|
||||
quote_abm = await self._convert_quote_message(quote)
|
||||
if quote_abm:
|
||||
sender_id = quote_abm.sender.user_id
|
||||
if isinstance(sender_id, str) and sender_id.isdigit():
|
||||
sender_id = int(sender_id)
|
||||
elif not isinstance(sender_id, int):
|
||||
sender_id = 0 # 默认值
|
||||
|
||||
reply_component = Reply(
|
||||
id=quote_abm.message_id,
|
||||
chain=quote_abm.message,
|
||||
sender_id=quote_abm.sender.user_id,
|
||||
sender_nickname=quote_abm.sender.nickname,
|
||||
time=quote_abm.timestamp,
|
||||
message_str=quote_abm.message_str,
|
||||
text=quote_abm.message_str,
|
||||
qq=sender_id,
|
||||
)
|
||||
abm.message.append(reply_component)
|
||||
|
||||
# 解析消息内容
|
||||
content_elements = await self.parse_satori_elements(content_for_parsing)
|
||||
abm.message.extend(content_elements)
|
||||
|
||||
# parse message_str
|
||||
abm.message_str = ""
|
||||
for comp in abm.message:
|
||||
for comp in content_elements:
|
||||
if isinstance(comp, Plain):
|
||||
abm.message_str += comp.text
|
||||
|
||||
@@ -333,6 +386,189 @@ class SatoriPlatformAdapter(Platform):
|
||||
logger.error(f"转换 Satori 消息失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_namespace_prefixes(self, content: str) -> set:
|
||||
"""提取XML内容中的命名空间前缀"""
|
||||
prefixes = set()
|
||||
|
||||
# 查找所有标签
|
||||
i = 0
|
||||
while i < len(content):
|
||||
# 查找开始标签
|
||||
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
|
||||
# 找到标签结束位置
|
||||
tag_end = content.find(">", i)
|
||||
if tag_end != -1:
|
||||
# 提取标签内容
|
||||
tag_content = content[i + 1 : tag_end]
|
||||
# 检查是否有命名空间前缀
|
||||
if ":" in tag_content and "xmlns:" not in tag_content:
|
||||
# 分割标签名
|
||||
parts = tag_content.split()
|
||||
if parts:
|
||||
tag_name = parts[0]
|
||||
if ":" in tag_name:
|
||||
prefix = tag_name.split(":")[0]
|
||||
# 确保是有效的命名空间前缀
|
||||
if (
|
||||
prefix.isalnum()
|
||||
or prefix.replace("_", "").isalnum()
|
||||
):
|
||||
prefixes.add(prefix)
|
||||
i = tag_end + 1
|
||||
else:
|
||||
i += 1
|
||||
# 查找结束标签
|
||||
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
|
||||
# 找到标签结束位置
|
||||
tag_end = content.find(">", i)
|
||||
if tag_end != -1:
|
||||
# 提取标签内容
|
||||
tag_content = content[i + 2 : tag_end]
|
||||
# 检查是否有命名空间前缀
|
||||
if ":" in tag_content:
|
||||
prefix = tag_content.split(":")[0]
|
||||
# 确保是有效的命名空间前缀
|
||||
if prefix.isalnum() or prefix.replace("_", "").isalnum():
|
||||
prefixes.add(prefix)
|
||||
i = tag_end + 1
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return prefixes
|
||||
|
||||
async def _extract_quote_element(self, content: str) -> Optional[dict]:
|
||||
"""提取<quote>标签信息"""
|
||||
try:
|
||||
# 处理命名空间前缀问题
|
||||
processed_content = content
|
||||
if ":" in content and not content.startswith("<root"):
|
||||
prefixes = self._extract_namespace_prefixes(content)
|
||||
|
||||
# 构建命名空间声明
|
||||
ns_declarations = " ".join(
|
||||
[
|
||||
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||
for prefix in prefixes
|
||||
]
|
||||
)
|
||||
|
||||
# 包装内容
|
||||
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||
elif not content.startswith("<root"):
|
||||
processed_content = f"<root>{content}</root>"
|
||||
else:
|
||||
processed_content = content
|
||||
|
||||
root = ET.fromstring(processed_content)
|
||||
|
||||
# 查找<quote>标签
|
||||
quote_element = None
|
||||
for elem in root.iter():
|
||||
tag_name = elem.tag
|
||||
if "}" in tag_name:
|
||||
tag_name = tag_name.split("}")[1]
|
||||
if tag_name.lower() == "quote":
|
||||
quote_element = elem
|
||||
break
|
||||
|
||||
if quote_element is not None:
|
||||
# 提取quote标签的属性
|
||||
quote_id = quote_element.get("id", "")
|
||||
|
||||
# 提取<quote>标签内部的内容
|
||||
inner_content = ""
|
||||
if quote_element.text:
|
||||
inner_content += quote_element.text
|
||||
for child in quote_element:
|
||||
inner_content += ET.tostring(
|
||||
child, encoding="unicode", method="xml"
|
||||
)
|
||||
if child.tail:
|
||||
inner_content += child.tail
|
||||
|
||||
# 构造移除了<quote>标签的内容
|
||||
content_without_quote = content.replace(
|
||||
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
|
||||
)
|
||||
|
||||
return {
|
||||
"quote": {"id": quote_id, "content": inner_content},
|
||||
"content_without_quote": content_without_quote,
|
||||
}
|
||||
|
||||
return None
|
||||
except ET.ParseError as e:
|
||||
logger.warning(f"XML解析失败,使用正则提取: {e}")
|
||||
return await self._extract_quote_with_regex(content)
|
||||
except Exception as e:
|
||||
logger.error(f"提取<quote>标签时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _extract_quote_with_regex(self, content: str) -> Optional[dict]:
|
||||
"""使用正则表达式提取quote标签信息"""
|
||||
import re
|
||||
|
||||
quote_pattern = r"<quote\s+([^>]*)>(.*?)</quote>"
|
||||
match = re.search(quote_pattern, content, re.DOTALL)
|
||||
|
||||
if not match:
|
||||
return None
|
||||
|
||||
attrs_str = match.group(1)
|
||||
inner_content = match.group(2)
|
||||
|
||||
id_match = re.search(r'id\s*=\s*["\']([^"\']*)["\']', attrs_str)
|
||||
quote_id = id_match.group(1) if id_match else ""
|
||||
content_without_quote = content.replace(match.group(0), "")
|
||||
content_without_quote = content_without_quote.strip()
|
||||
|
||||
return {
|
||||
"quote": {"id": quote_id, "content": inner_content},
|
||||
"content_without_quote": content_without_quote,
|
||||
}
|
||||
|
||||
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
||||
"""转换引用消息"""
|
||||
try:
|
||||
quote_abm = AstrBotMessage()
|
||||
quote_abm.message_id = quote.get("id", "")
|
||||
|
||||
# 解析引用消息的发送者
|
||||
quote_author = quote.get("author", {})
|
||||
if quote_author:
|
||||
quote_abm.sender = MessageMember(
|
||||
user_id=quote_author.get("id", ""),
|
||||
nickname=quote_author.get("nick", quote_author.get("name", "")),
|
||||
)
|
||||
else:
|
||||
# 如果没有作者信息,使用默认值
|
||||
quote_abm.sender = MessageMember(
|
||||
user_id=quote.get("user_id", ""),
|
||||
nickname="内容",
|
||||
)
|
||||
|
||||
# 解析引用消息内容
|
||||
quote_content = quote.get("content", "")
|
||||
quote_abm.message = await self.parse_satori_elements(quote_content)
|
||||
|
||||
quote_abm.message_str = ""
|
||||
for comp in quote_abm.message:
|
||||
if isinstance(comp, Plain):
|
||||
quote_abm.message_str += comp.text
|
||||
|
||||
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
|
||||
|
||||
# 如果没有任何内容,使用默认文本
|
||||
if not quote_abm.message_str.strip():
|
||||
quote_abm.message_str = "[引用消息]"
|
||||
|
||||
return quote_abm
|
||||
except Exception as e:
|
||||
logger.error(f"转换引用消息失败: {e}")
|
||||
return None
|
||||
|
||||
async def parse_satori_elements(self, content: str) -> list:
|
||||
"""解析 Satori 消息元素"""
|
||||
elements = []
|
||||
@@ -341,12 +577,35 @@ class SatoriPlatformAdapter(Platform):
|
||||
return elements
|
||||
|
||||
try:
|
||||
wrapped_content = f"<root>{content}</root>"
|
||||
root = ET.fromstring(wrapped_content)
|
||||
# 处理命名空间前缀问题
|
||||
processed_content = content
|
||||
if ":" in content and not content.startswith("<root"):
|
||||
prefixes = self._extract_namespace_prefixes(content)
|
||||
|
||||
# 构建命名空间声明
|
||||
ns_declarations = " ".join(
|
||||
[
|
||||
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||
for prefix in prefixes
|
||||
]
|
||||
)
|
||||
|
||||
# 包装内容
|
||||
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||
elif not content.startswith("<root"):
|
||||
processed_content = f"<root>{content}</root>"
|
||||
else:
|
||||
processed_content = content
|
||||
|
||||
root = ET.fromstring(processed_content)
|
||||
await self._parse_xml_node(root, elements)
|
||||
except ET.ParseError as e:
|
||||
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
|
||||
logger.warning(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
||||
# 如果解析失败,将整个内容当作纯文本
|
||||
if content.strip():
|
||||
elements.append(Plain(text=content))
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
|
||||
raise e
|
||||
|
||||
# 如果没有解析到任何元素,将整个内容当作纯文本
|
||||
@@ -361,7 +620,12 @@ class SatoriPlatformAdapter(Platform):
|
||||
elements.append(Plain(text=node.text))
|
||||
|
||||
for child in node:
|
||||
tag_name = child.tag.lower()
|
||||
# 获取标签名,去除命名空间前缀
|
||||
tag_name = child.tag
|
||||
if "}" in tag_name:
|
||||
tag_name = tag_name.split("}")[1]
|
||||
tag_name = tag_name.lower()
|
||||
|
||||
attrs = child.attrib
|
||||
|
||||
if tag_name == "at":
|
||||
@@ -372,31 +636,59 @@ class SatoriPlatformAdapter(Platform):
|
||||
src = attrs.get("src", "")
|
||||
if not src:
|
||||
continue
|
||||
if src.startswith("data:image/"):
|
||||
src = src.split(",")[1]
|
||||
elements.append(Image.fromBase64(src))
|
||||
elif src.startswith("http"):
|
||||
elements.append(Image.fromURL(src))
|
||||
else:
|
||||
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
|
||||
elements.append(Image(file=src))
|
||||
|
||||
elif tag_name == "file":
|
||||
src = attrs.get("src", "")
|
||||
name = attrs.get("name", "文件")
|
||||
if src:
|
||||
elements.append(File(file=src, name=name))
|
||||
elements.append(File(name=name, file=src))
|
||||
|
||||
elif tag_name in ("audio", "record"):
|
||||
src = attrs.get("src", "")
|
||||
if not src:
|
||||
continue
|
||||
if src.startswith("data:audio/"):
|
||||
src = src.split(",")[1]
|
||||
elements.append(Record.fromBase64(src))
|
||||
elif src.startswith("http"):
|
||||
elements.append(Record.fromURL(src))
|
||||
elements.append(Record(file=src))
|
||||
|
||||
elif tag_name == "quote":
|
||||
# quote标签已经被特殊处理
|
||||
pass
|
||||
|
||||
elif tag_name == "face":
|
||||
face_id = attrs.get("id", "")
|
||||
face_name = attrs.get("name", "")
|
||||
face_type = attrs.get("type", "")
|
||||
|
||||
if face_name:
|
||||
elements.append(Plain(text=f"[表情:{face_name}]"))
|
||||
elif face_id and face_type:
|
||||
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
|
||||
elif face_id:
|
||||
elements.append(Plain(text=f"[表情ID:{face_id}]"))
|
||||
else:
|
||||
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
|
||||
elements.append(Plain(text="[表情]"))
|
||||
|
||||
elif tag_name == "ark":
|
||||
# 作为纯文本添加到消息链中
|
||||
data = attrs.get("data", "")
|
||||
if data:
|
||||
import html
|
||||
|
||||
decoded_data = html.unescape(data)
|
||||
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||
else:
|
||||
elements.append(Plain(text="[ARK卡片]"))
|
||||
|
||||
elif tag_name == "json":
|
||||
# JSON标签 视为ARK卡片消息
|
||||
data = attrs.get("data", "")
|
||||
if data:
|
||||
import html
|
||||
|
||||
decoded_data = html.unescape(data)
|
||||
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||
else:
|
||||
elements.append(Plain(text="[JSON卡片]"))
|
||||
|
||||
else:
|
||||
# 未知标签,递归处理其内容
|
||||
|
||||
@@ -2,7 +2,18 @@ from typing import TYPE_CHECKING
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
Video,
|
||||
Reply,
|
||||
Forward,
|
||||
Node,
|
||||
Nodes,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .satori_adapter import SatoriPlatformAdapter
|
||||
@@ -17,6 +28,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
session_id: str,
|
||||
adapter: "SatoriPlatformAdapter",
|
||||
):
|
||||
# 更新平台元数据
|
||||
if adapter and hasattr(adapter, "logins") and adapter.logins:
|
||||
current_login = adapter.logins[0]
|
||||
platform_name = current_login.get("platform", "satori")
|
||||
user = current_login.get("user", {})
|
||||
user_id = user.get("id", "") if user else ""
|
||||
if not platform_meta.id and user_id:
|
||||
platform_meta.id = f"{platform_name}({user_id})"
|
||||
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.adapter = adapter
|
||||
self.platform = None
|
||||
@@ -39,44 +59,24 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
content_parts = []
|
||||
|
||||
for component in message.chain:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
content_parts.append(text)
|
||||
component_content = await cls._convert_component_to_satori_static(
|
||||
component
|
||||
)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
content_parts.append(f'<at id="{component.qq}"/>')
|
||||
elif component.name:
|
||||
content_parts.append(f'<at name="{component.name}"/>')
|
||||
# 特殊处理 Node 和 Nodes 组件
|
||||
if isinstance(component, Node):
|
||||
# 单个转发节点
|
||||
node_content = await cls._convert_node_to_satori_static(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
content_parts.append(
|
||||
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
content_parts.append(
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
content_parts.append(
|
||||
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
elif isinstance(component, Nodes):
|
||||
# 合并转发消息
|
||||
node_content = await cls._convert_nodes_to_satori_static(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
channel_id = session_id
|
||||
@@ -118,44 +118,22 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
content_parts = []
|
||||
|
||||
for component in message.chain:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
content_parts.append(text)
|
||||
component_content = await self._convert_component_to_satori(component)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
content_parts.append(f'<at id="{component.qq}"/>')
|
||||
elif component.name:
|
||||
content_parts.append(f'<at name="{component.name}"/>')
|
||||
# 特殊处理 Node 和 Nodes 组件
|
||||
if isinstance(component, Node):
|
||||
# 单个转发节点
|
||||
node_content = await self._convert_node_to_satori(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
content_parts.append(
|
||||
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
content_parts.append(
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
content_parts.append(
|
||||
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
elif isinstance(component, Nodes):
|
||||
# 合并转发消息
|
||||
node_content = await self._convert_nodes_to_satori(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
channel_id = self.session_id
|
||||
@@ -219,3 +197,227 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
logger.error(f"Satori 流式消息发送异常: {e}")
|
||||
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _convert_component_to_satori(self, component) -> str:
|
||||
"""将单个消息组件转换为 Satori 格式"""
|
||||
try:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
return text
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
return f'<at id="{component.qq}"/>'
|
||||
elif component.name:
|
||||
return f'<at name="{component.name}"/>'
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
return (
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
return f'<reply id="{component.id}"/>'
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
return f'<video src="{video_path_url}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
|
||||
elif isinstance(component, Forward):
|
||||
return f'<message id="{component.id}" forward/>'
|
||||
|
||||
# 对于其他未处理的组件类型,返回空字符串
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息组件失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _convert_node_to_satori(self, node: Node) -> str:
|
||||
"""将单个转发节点转换为 Satori 格式"""
|
||||
try:
|
||||
content_parts = []
|
||||
if node.content:
|
||||
for content_component in node.content:
|
||||
component_content = await self._convert_component_to_satori(
|
||||
content_component
|
||||
)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
|
||||
# 如果内容为空,添加默认内容
|
||||
if not content.strip():
|
||||
content = "[转发消息]"
|
||||
|
||||
# 构建 Satori 格式的转发节点
|
||||
author_attrs = []
|
||||
if node.uin:
|
||||
author_attrs.append(f'id="{node.uin}"')
|
||||
if node.name:
|
||||
author_attrs.append(f'name="{node.name}"')
|
||||
|
||||
author_attr_str = " ".join(author_attrs)
|
||||
|
||||
return f"<message><author {author_attr_str}/>{content}</message>"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换转发节点失败: {e}")
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
async def _convert_component_to_satori_static(cls, component) -> str:
|
||||
"""将单个消息组件转换为 Satori 格式"""
|
||||
try:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
return text
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
return f'<at id="{component.qq}"/>'
|
||||
elif component.name:
|
||||
return f'<at name="{component.name}"/>'
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
return (
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
return f'<reply id="{component.id}"/>'
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
return f'<video src="{video_path_url}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
|
||||
elif isinstance(component, Forward):
|
||||
return f'<message id="{component.id}" forward/>'
|
||||
|
||||
# 对于其他未处理的组件类型,返回空字符串
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息组件失败: {e}")
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
async def _convert_node_to_satori_static(cls, node: Node) -> str:
|
||||
"""将单个转发节点转换为 Satori 格式"""
|
||||
try:
|
||||
content_parts = []
|
||||
if node.content:
|
||||
for content_component in node.content:
|
||||
component_content = await cls._convert_component_to_satori_static(
|
||||
content_component
|
||||
)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
|
||||
# 如果内容为空,添加默认内容
|
||||
if not content.strip():
|
||||
content = "[转发消息]"
|
||||
|
||||
author_attrs = []
|
||||
if node.uin:
|
||||
author_attrs.append(f'id="{node.uin}"')
|
||||
if node.name:
|
||||
author_attrs.append(f'name="{node.name}"')
|
||||
|
||||
author_attr_str = " ".join(author_attrs)
|
||||
|
||||
return f"<message><author {author_attr_str}/>{content}</message>"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换转发节点失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _convert_nodes_to_satori(self, nodes: Nodes) -> str:
|
||||
"""将多个转发节点转换为 Satori 格式的合并转发"""
|
||||
try:
|
||||
node_parts = []
|
||||
|
||||
for node in nodes.nodes:
|
||||
node_content = await self._convert_node_to_satori(node)
|
||||
if node_content:
|
||||
node_parts.append(node_content)
|
||||
|
||||
if node_parts:
|
||||
return f"<message forward>{''.join(node_parts)}</message>"
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换合并转发消息失败: {e}")
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
async def _convert_nodes_to_satori_static(cls, nodes: Nodes) -> str:
|
||||
"""将多个转发节点转换为 Satori 格式的合并转发"""
|
||||
try:
|
||||
node_parts = []
|
||||
|
||||
for node in nodes.nodes:
|
||||
node_content = await cls._convert_node_to_satori_static(node)
|
||||
if node_content:
|
||||
node_parts.append(node_content)
|
||||
|
||||
if node_parts:
|
||||
return f"<message forward>{''.join(node_parts)}</message>"
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换合并转发消息失败: {e}")
|
||||
return ""
|
||||
|
||||
@@ -95,9 +95,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||
)
|
||||
id_ = self.config.get("id") or "telegram"
|
||||
return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
@@ -117,6 +116,10 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
if not self.application.updater:
|
||||
logger.error("Telegram Updater is not initialized. Cannot start polling.")
|
||||
return
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
await queue
|
||||
@@ -194,6 +197,11 @@ class TelegramPlatformAdapter(Platform):
|
||||
return cmd_name, description
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if not update.effective_chat:
|
||||
logger.warning(
|
||||
"Received a start command without an effective chat, skipping /start reply."
|
||||
)
|
||||
return
|
||||
await context.bot.send_message(
|
||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||
)
|
||||
@@ -206,15 +214,20 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
async def convert_message(
|
||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||
) -> AstrBotMessage:
|
||||
) -> AstrBotMessage | None:
|
||||
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||
|
||||
@param update: Telegram 的 Update 对象。
|
||||
@param context: Telegram 的 Context 对象。
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
if not update.message:
|
||||
logger.warning("Received an update without a message.")
|
||||
return None
|
||||
|
||||
message = AstrBotMessage()
|
||||
message.session_id = str(update.message.chat.id)
|
||||
|
||||
# 获得是群聊还是私聊
|
||||
if update.message.chat.type == ChatType.PRIVATE:
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -225,10 +238,13 @@ class TelegramPlatformAdapter(Platform):
|
||||
# Topic Group
|
||||
message.group_id += "#" + str(update.message.message_thread_id)
|
||||
message.session_id = message.group_id
|
||||
|
||||
message.message_id = str(update.message.message_id)
|
||||
_from_user = update.message.from_user
|
||||
if not _from_user:
|
||||
logger.warning("[Telegram] Received a message without a from_user.")
|
||||
return None
|
||||
message.sender = MessageMember(
|
||||
str(update.message.from_user.id), update.message.from_user.username
|
||||
str(_from_user.id), _from_user.username or "Unknown"
|
||||
)
|
||||
message.self_id = str(context.bot.username)
|
||||
message.raw_message = update
|
||||
@@ -247,22 +263,32 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
reply_abm = await self.convert_message(reply_update, context, False)
|
||||
|
||||
message.message.append(
|
||||
Comp.Reply(
|
||||
id=reply_abm.message_id,
|
||||
chain=reply_abm.message,
|
||||
sender_id=reply_abm.sender.user_id,
|
||||
sender_nickname=reply_abm.sender.nickname,
|
||||
time=reply_abm.timestamp,
|
||||
message_str=reply_abm.message_str,
|
||||
text=reply_abm.message_str,
|
||||
qq=reply_abm.sender.user_id,
|
||||
if reply_abm:
|
||||
message.message.append(
|
||||
Comp.Reply(
|
||||
id=reply_abm.message_id,
|
||||
chain=reply_abm.message,
|
||||
sender_id=reply_abm.sender.user_id,
|
||||
sender_nickname=reply_abm.sender.nickname,
|
||||
time=reply_abm.timestamp,
|
||||
message_str=reply_abm.message_str,
|
||||
text=reply_abm.message_str,
|
||||
qq=reply_abm.sender.user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if update.message.text:
|
||||
# 处理文本消息
|
||||
plain_text = update.message.text
|
||||
if (
|
||||
message.type == MessageType.GROUP_MESSAGE
|
||||
and update.message
|
||||
and update.message.reply_to_message
|
||||
and update.message.reply_to_message.from_user
|
||||
and update.message.reply_to_message.from_user.id == context.bot.id
|
||||
):
|
||||
plain_text2 = f"/@{context.bot.username} " + plain_text
|
||||
plain_text = plain_text2
|
||||
|
||||
# 群聊场景命令特殊处理
|
||||
if plain_text.startswith("/"):
|
||||
@@ -328,15 +354,25 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
elif update.message.document:
|
||||
file = await update.message.document.get_file()
|
||||
message.message = [
|
||||
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||
]
|
||||
file_name = update.message.document.file_name or uuid.uuid4().hex
|
||||
file_path = file.file_path
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}."
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
message.message = [
|
||||
Comp.Video(file=file.file_path, path=file.file_path),
|
||||
]
|
||||
file_name = update.message.video.file_name or uuid.uuid4().hex
|
||||
file_path = file.file_path
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"Telegram video file_path is None, cannot save the file {file_name}."
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.Video(file=file_path, path=file.file_path))
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -135,6 +136,39 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str | None, big: bool = False):
|
||||
"""
|
||||
给原消息添加 Telegram 反应:
|
||||
- 普通 emoji:传入 '👍'、'😂' 等
|
||||
- 自定义表情:传入其 custom_emoji_id(纯数字字符串)
|
||||
- 取消本机器人的反应:传入 None 或空字符串
|
||||
"""
|
||||
try:
|
||||
# 解析 chat_id(去掉超级群的 "#<thread_id>" 片段)
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
chat_id = (self.message_obj.group_id or "").split("#")[0]
|
||||
else:
|
||||
chat_id = self.get_sender_id()
|
||||
|
||||
message_id = int(self.message_obj.message_id)
|
||||
|
||||
# 组装 reaction 参数(必须是 ReactionType 的列表)
|
||||
if not emoji: # 清空本 bot 的反应
|
||||
reaction_param = [] # 空列表表示移除本 bot 的反应
|
||||
elif emoji.isdigit(): # 自定义表情:传 custom_emoji_id
|
||||
reaction_param = [ReactionTypeCustomEmoji(emoji)]
|
||||
else: # 普通 emoji
|
||||
reaction_param = [ReactionTypeEmoji(emoji)]
|
||||
|
||||
await self.client.set_message_reaction(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
reaction=reaction_param, # 注意是列表
|
||||
is_big=big, # 可选:大动画
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] 添加反应失败: {e}")
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
message_thread_id = None
|
||||
|
||||
@@ -218,7 +252,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
delta = ""
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
|
||||
@@ -91,7 +91,6 @@ class WebChatAdapter(Platform):
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = "webchat"
|
||||
abm.tag = "webchat"
|
||||
abm.sender = MessageMember(username, username)
|
||||
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
id=self.config.get("id", "wecom"),
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
289
astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py
Normal file
289
astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding:utf-8 -*-
|
||||
|
||||
"""对企业微信发送给企业后台的消息加解密示例代码.
|
||||
@copyright: Copyright (c) 1998-2020 Tencent Inc.
|
||||
|
||||
"""
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import random
|
||||
import hashlib
|
||||
import time
|
||||
import struct
|
||||
from Crypto.Cipher import AES
|
||||
import socket
|
||||
import json
|
||||
|
||||
from . import ierror
|
||||
|
||||
"""
|
||||
关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案
|
||||
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
|
||||
下载后,按照README中的“Installation”小节的提示进行pycrypto安装。
|
||||
"""
|
||||
|
||||
|
||||
class FormatException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def throw_exception(message, exception_class=FormatException):
|
||||
"""my define raise exception function"""
|
||||
raise exception_class(message)
|
||||
|
||||
|
||||
class SHA1:
|
||||
"""计算企业微信的消息签名接口"""
|
||||
|
||||
def getSHA1(self, token, timestamp, nonce, encrypt):
|
||||
"""用SHA1算法生成安全签名
|
||||
@param token: 票据
|
||||
@param timestamp: 时间戳
|
||||
@param encrypt: 密文
|
||||
@param nonce: 随机字符串
|
||||
@return: 安全签名
|
||||
"""
|
||||
try:
|
||||
# 确保所有输入都是字符串类型
|
||||
if isinstance(encrypt, bytes):
|
||||
encrypt = encrypt.decode("utf-8")
|
||||
|
||||
sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)]
|
||||
sortlist.sort()
|
||||
sha = hashlib.sha1()
|
||||
sha.update("".join(sortlist).encode("utf-8"))
|
||||
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
|
||||
|
||||
|
||||
class JsonParse:
|
||||
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
|
||||
|
||||
# json消息模板
|
||||
AES_TEXT_RESPONSE_TEMPLATE = """{
|
||||
"encrypt": "%(msg_encrypt)s",
|
||||
"msgsignature": "%(msg_signaturet)s",
|
||||
"timestamp": "%(timestamp)s",
|
||||
"nonce": "%(nonce)s"
|
||||
}"""
|
||||
|
||||
def extract(self, jsontext):
|
||||
"""提取出json数据包中的加密消息
|
||||
@param jsontext: 待提取的json字符串
|
||||
@return: 提取出的加密消息字符串
|
||||
"""
|
||||
try:
|
||||
json_dict = json.loads(jsontext)
|
||||
return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_ParseJson_Error, None
|
||||
|
||||
def generate(self, encrypt, signature, timestamp, nonce):
|
||||
"""生成json消息
|
||||
@param encrypt: 加密后的消息密文
|
||||
@param signature: 安全签名
|
||||
@param timestamp: 时间戳
|
||||
@param nonce: 随机字符串
|
||||
@return: 生成的json字符串
|
||||
"""
|
||||
resp_dict = {
|
||||
"msg_encrypt": encrypt,
|
||||
"msg_signaturet": signature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
}
|
||||
resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
|
||||
return resp_json
|
||||
|
||||
|
||||
class PKCS7Encoder:
|
||||
"""提供基于PKCS7算法的加解密接口"""
|
||||
|
||||
block_size = 32
|
||||
|
||||
def encode(self, text):
|
||||
"""对需要加密的明文进行填充补位
|
||||
@param text: 需要进行填充补位操作的明文(bytes类型)
|
||||
@return: 补齐明文字符串(bytes类型)
|
||||
"""
|
||||
text_length = len(text)
|
||||
# 计算需要填充的位数
|
||||
amount_to_pad = self.block_size - (text_length % self.block_size)
|
||||
if amount_to_pad == 0:
|
||||
amount_to_pad = self.block_size
|
||||
# 获得补位所用的字符
|
||||
pad = bytes([amount_to_pad])
|
||||
# 确保text是bytes类型
|
||||
if isinstance(text, str):
|
||||
text = text.encode("utf-8")
|
||||
return text + pad * amount_to_pad
|
||||
|
||||
def decode(self, decrypted):
|
||||
"""删除解密后明文的补位字符
|
||||
@param decrypted: 解密后的明文
|
||||
@return: 删除补位字符后的明文
|
||||
"""
|
||||
pad = ord(decrypted[-1])
|
||||
if pad < 1 or pad > 32:
|
||||
pad = 0
|
||||
return decrypted[:-pad]
|
||||
|
||||
|
||||
class Prpcrypt(object):
|
||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||
|
||||
def __init__(self, key):
|
||||
# self.key = base64.b64decode(key+"=")
|
||||
self.key = key
|
||||
# 设置加解密模式为AES的CBC模式
|
||||
self.mode = AES.MODE_CBC
|
||||
|
||||
def encrypt(self, text, receiveid):
|
||||
"""对明文进行加密
|
||||
@param text: 需要加密的明文
|
||||
@return: 加密得到的字符串
|
||||
"""
|
||||
# 16位随机字符串添加到明文开头
|
||||
text = text.encode()
|
||||
text = (
|
||||
self.get_random_str()
|
||||
+ struct.pack("I", socket.htonl(len(text)))
|
||||
+ text
|
||||
+ receiveid.encode()
|
||||
)
|
||||
|
||||
# 使用自定义的填充方式对明文进行补位填充
|
||||
pkcs7 = PKCS7Encoder()
|
||||
text = pkcs7.encode(text)
|
||||
# 加密
|
||||
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
|
||||
try:
|
||||
ciphertext = cryptor.encrypt(text)
|
||||
# 使用BASE64对加密后的字符串进行编码
|
||||
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
|
||||
except Exception as e:
|
||||
logger = logging.getLogger("astrbot")
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
|
||||
|
||||
def decrypt(self, text, receiveid):
|
||||
"""对解密后的明文进行补位删除
|
||||
@param text: 密文
|
||||
@return: 删除填充补位后的明文
|
||||
"""
|
||||
try:
|
||||
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
|
||||
# 使用BASE64对密文进行解码,然后AES-CBC解密
|
||||
plain_text = cryptor.decrypt(base64.b64decode(text))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
|
||||
try:
|
||||
pad = plain_text[-1]
|
||||
# 去掉补位字符串
|
||||
# pkcs7 = PKCS7Encoder()
|
||||
# plain_text = pkcs7.encode(plain_text)
|
||||
# 去除16位随机字符串
|
||||
content = plain_text[16:-pad]
|
||||
json_len = socket.ntohl(struct.unpack("I", content[:4])[0])
|
||||
json_content = content[4 : json_len + 4].decode("utf-8")
|
||||
from_receiveid = content[json_len + 4 :].decode("utf-8")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_IllegalBuffer, None
|
||||
if from_receiveid != receiveid:
|
||||
print("receiveid not match", receiveid, from_receiveid)
|
||||
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
|
||||
return 0, json_content
|
||||
|
||||
def get_random_str(self):
|
||||
"""随机生成16位字符串
|
||||
@return: 16位字符串
|
||||
"""
|
||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||
|
||||
|
||||
class WXBizJsonMsgCrypt(object):
|
||||
# 构造函数
|
||||
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
||||
try:
|
||||
self.key = base64.b64decode(sEncodingAESKey + "=")
|
||||
assert len(self.key) == 32
|
||||
except Exception as e:
|
||||
throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException)
|
||||
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
||||
self.m_sToken = sToken
|
||||
self.m_sReceiveId = sReceiveId
|
||||
|
||||
# 验证URL
|
||||
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||
# @param sNonce: 随机串,对应URL参数的nonce
|
||||
# @param sEchoStr: 随机串,对应URL参数的echostr
|
||||
# @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效
|
||||
# @return:成功0,失败返回对应的错误码
|
||||
|
||||
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if not signature == sMsgSignature:
|
||||
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
|
||||
return ret, sReplyEchoStr
|
||||
|
||||
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
|
||||
# 将企业回复用户的消息加密打包
|
||||
# @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串
|
||||
# @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间
|
||||
# @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce
|
||||
# sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串,
|
||||
# return:成功0,sEncryptMsg,失败返回对应的错误码None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
|
||||
encrypt = encrypt.decode("utf-8") # type: ignore
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if timestamp is None:
|
||||
timestamp = str(int(time.time()))
|
||||
# 生成安全签名
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
jsonParse = JsonParse()
|
||||
return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce)
|
||||
|
||||
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
|
||||
# 检验消息的真实性,并且获取解密后的明文
|
||||
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||
# @param sNonce: 随机串,对应URL参数的nonce
|
||||
# @param sPostData: 密文,对应POST请求的数据
|
||||
# json_content: 解密后的原文,当return返回0时有效
|
||||
# @return: 成功0,失败返回对应的错误码
|
||||
# 验证安全签名
|
||||
jsonParse = JsonParse()
|
||||
ret, encrypt = jsonParse.extract(sPostData)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if not signature == sMsgSignature:
|
||||
print("signature not match")
|
||||
print(signature)
|
||||
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId)
|
||||
return ret, json_content
|
||||
17
astrbot/core/platform/sources/wecom_ai_bot/__init__.py
Normal file
17
astrbot/core/platform/sources/wecom_ai_bot/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
企业微信智能机器人平台适配器包
|
||||
"""
|
||||
|
||||
from .wecomai_adapter import WecomAIBotAdapter
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_event import WecomAIBotMessageEvent
|
||||
from .wecomai_server import WecomAIBotServer
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
|
||||
__all__ = [
|
||||
"WecomAIBotAdapter",
|
||||
"WecomAIBotAPIClient",
|
||||
"WecomAIBotMessageEvent",
|
||||
"WecomAIBotServer",
|
||||
"WecomAIBotConstants",
|
||||
]
|
||||
20
astrbot/core/platform/sources/wecom_ai_bot/ierror.py
Normal file
20
astrbot/core/platform/sources/wecom_ai_bot/ierror.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#########################################################################
|
||||
# Author: jonyqin
|
||||
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
|
||||
# File Name: ierror.py
|
||||
# Description:定义错误码含义
|
||||
#########################################################################
|
||||
WXBizMsgCrypt_OK = 0
|
||||
WXBizMsgCrypt_ValidateSignature_Error = -40001
|
||||
WXBizMsgCrypt_ParseJson_Error = -40002
|
||||
WXBizMsgCrypt_ComputeSignature_Error = -40003
|
||||
WXBizMsgCrypt_IllegalAesKey = -40004
|
||||
WXBizMsgCrypt_ValidateCorpid_Error = -40005
|
||||
WXBizMsgCrypt_EncryptAES_Error = -40006
|
||||
WXBizMsgCrypt_DecryptAES_Error = -40007
|
||||
WXBizMsgCrypt_IllegalBuffer = -40008
|
||||
WXBizMsgCrypt_EncodeBase64_Error = -40009
|
||||
WXBizMsgCrypt_DecodeBase64_Error = -40010
|
||||
WXBizMsgCrypt_GenReturnJson_Error = -40011
|
||||
445
astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py
Normal file
445
astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
企业微信智能机器人平台适配器
|
||||
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
|
||||
参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import hashlib
|
||||
import base64
|
||||
from typing import Awaitable, Any, Dict, Optional, Callable
|
||||
|
||||
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, At, Image
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
from .wecomai_api import (
|
||||
WecomAIBotAPIClient,
|
||||
WecomAIBotMessageParser,
|
||||
WecomAIBotStreamMessageBuilder,
|
||||
)
|
||||
from .wecomai_event import WecomAIBotMessageEvent
|
||||
from .wecomai_server import WecomAIBotServer
|
||||
from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr
|
||||
from .wecomai_utils import (
|
||||
WecomAIBotConstants,
|
||||
format_session_id,
|
||||
generate_random_string,
|
||||
process_encrypted_image,
|
||||
)
|
||||
|
||||
|
||||
class WecomAIQueueListener:
|
||||
"""企业微信智能机器人队列监听器,参考webchat的QueueListener设计"""
|
||||
|
||||
def __init__(
|
||||
self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]]
|
||||
) -> None:
|
||||
self.queue_mgr = queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, session_id: str):
|
||||
"""监听特定会话的队列"""
|
||||
queue = self.queue_mgr.get_or_create_queue(session_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""监控新会话队列并启动监听器"""
|
||||
monitored_sessions = set()
|
||||
|
||||
while True:
|
||||
# 检查新会话
|
||||
current_sessions = set(self.queue_mgr.queues.keys())
|
||||
new_sessions = current_sessions - monitored_sessions
|
||||
|
||||
# 为新会话启动监听器
|
||||
for session_id in new_sessions:
|
||||
task = asyncio.create_task(self.listen_to_queue(session_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_sessions.add(session_id)
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
# 清理已不存在的会话
|
||||
removed_sessions = monitored_sessions - current_sessions
|
||||
monitored_sessions -= removed_sessions
|
||||
|
||||
# 清理过期的待处理响应
|
||||
self.queue_mgr.cleanup_expired_responses()
|
||||
|
||||
await asyncio.sleep(1) # 每秒检查一次新会话
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息"
|
||||
)
|
||||
class WecomAIBotAdapter(Platform):
|
||||
"""企业微信智能机器人适配器"""
|
||||
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
|
||||
# 初始化配置参数
|
||||
self.token = self.config["token"]
|
||||
self.encoding_aes_key = self.config["encoding_aes_key"]
|
||||
self.port = int(self.config["port"])
|
||||
self.host = self.config.get("callback_server_host", "0.0.0.0")
|
||||
self.bot_name = self.config.get("wecom_ai_bot_name", "")
|
||||
self.initial_respond_text = self.config.get(
|
||||
"wecomaibot_init_respond_text", "💭 思考中..."
|
||||
)
|
||||
self.friend_message_welcome_text = self.config.get(
|
||||
"wecomaibot_friend_message_welcome_text", ""
|
||||
)
|
||||
|
||||
# 平台元数据
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wecom_ai_bot",
|
||||
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
|
||||
id=self.config.get("id", "wecom_ai_bot"),
|
||||
)
|
||||
|
||||
# 初始化 API 客户端
|
||||
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
|
||||
|
||||
# 初始化 HTTP 服务器
|
||||
self.server = WecomAIBotServer(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
api_client=self.api_client,
|
||||
message_handler=self._process_message,
|
||||
)
|
||||
|
||||
# 事件循环和关闭信号
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
# 队列监听器
|
||||
self.queue_listener = WecomAIQueueListener(
|
||||
wecomai_queue_mgr, self._handle_queued_message
|
||||
)
|
||||
|
||||
async def _handle_queued_message(self, data: dict):
|
||||
"""处理队列中的消息,类似webchat的callback"""
|
||||
try:
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
except Exception as e:
|
||||
logger.error(f"处理队列消息时发生异常: {e}")
|
||||
|
||||
async def _process_message(
|
||||
self, message_data: Dict[str, Any], callback_params: Dict[str, str]
|
||||
) -> Optional[str]:
|
||||
"""处理接收到的消息
|
||||
|
||||
Args:
|
||||
message_data: 解密后的消息数据
|
||||
callback_params: 回调参数 (nonce, timestamp)
|
||||
|
||||
Returns:
|
||||
加密后的响应消息,无需响应时返回 None
|
||||
"""
|
||||
msgtype = message_data.get("msgtype")
|
||||
if not msgtype:
|
||||
logger.warning(f"消息类型未知,忽略: {message_data}")
|
||||
return None
|
||||
session_id = self._extract_session_id(message_data)
|
||||
if msgtype in ("text", "image", "mixed"):
|
||||
# user sent a text / image / mixed message
|
||||
try:
|
||||
# create a brand-new unique stream_id for this message session
|
||||
stream_id = f"{session_id}_{generate_random_string(10)}"
|
||||
await self._enqueue_message(
|
||||
message_data, callback_params, stream_id, session_id
|
||||
)
|
||||
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
|
||||
|
||||
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
stream_id, self.initial_respond_text, False
|
||||
)
|
||||
return await self.api_client.encrypt_message(
|
||||
resp, callback_params["nonce"], callback_params["timestamp"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("处理消息时发生异常: %s", e)
|
||||
return None
|
||||
elif msgtype == "stream":
|
||||
# wechat server is requesting for updates of a stream
|
||||
stream_id = message_data["stream"]["id"]
|
||||
if not wecomai_queue_mgr.has_back_queue(stream_id):
|
||||
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
|
||||
# 返回结束标志,告诉微信服务器流已结束
|
||||
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
stream_id, "", True
|
||||
)
|
||||
resp = await self.api_client.encrypt_message(
|
||||
end_message,
|
||||
callback_params["nonce"],
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
return resp
|
||||
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
if queue.empty():
|
||||
logger.debug(
|
||||
f"No new messages in back queue for stream_id: {stream_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# aggregate all delta chains in the back queue
|
||||
latest_plain_content = ""
|
||||
image_base64 = []
|
||||
finish = False
|
||||
while not queue.empty():
|
||||
msg = await queue.get()
|
||||
if msg["type"] == "plain":
|
||||
latest_plain_content = msg["data"] or ""
|
||||
elif msg["type"] == "image":
|
||||
image_base64.append(msg["image_data"])
|
||||
elif msg["type"] == "end":
|
||||
# stream end
|
||||
finish = True
|
||||
wecomai_queue_mgr.remove_queues(stream_id)
|
||||
break
|
||||
else:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}"
|
||||
)
|
||||
if latest_plain_content or image_base64:
|
||||
msg_items = []
|
||||
if finish and image_base64:
|
||||
for img_b64 in image_base64:
|
||||
# get md5 of image
|
||||
img_data = base64.b64decode(img_b64)
|
||||
img_md5 = hashlib.md5(img_data).hexdigest()
|
||||
msg_items.append(
|
||||
{
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
|
||||
"image": {"base64": img_b64, "md5": img_md5},
|
||||
}
|
||||
)
|
||||
image_base64 = []
|
||||
|
||||
plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream(
|
||||
stream_id, latest_plain_content, msg_items, finish
|
||||
)
|
||||
encrypted_message = await self.api_client.encrypt_message(
|
||||
plain_message,
|
||||
callback_params["nonce"],
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
if encrypted_message:
|
||||
logger.debug(
|
||||
f"Stream message sent successfully, stream_id: {stream_id}"
|
||||
)
|
||||
else:
|
||||
logger.error("消息加密失败")
|
||||
return encrypted_message
|
||||
return None
|
||||
elif msgtype == "event":
|
||||
event = message_data.get("event")
|
||||
if event == "enter_chat" and self.friend_message_welcome_text:
|
||||
# 用户进入会话,发送欢迎消息
|
||||
try:
|
||||
resp = WecomAIBotStreamMessageBuilder.make_text(
|
||||
self.friend_message_welcome_text
|
||||
)
|
||||
return await self.api_client.encrypt_message(
|
||||
resp,
|
||||
callback_params["nonce"],
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("处理欢迎消息时发生异常: %s", e)
|
||||
return None
|
||||
pass
|
||||
|
||||
def _extract_session_id(self, message_data: Dict[str, Any]) -> str:
|
||||
"""从消息数据中提取会话ID"""
|
||||
user_id = message_data.get("from", {}).get("userid", "default_user")
|
||||
return format_session_id("wecomai", user_id)
|
||||
|
||||
async def _enqueue_message(
|
||||
self,
|
||||
message_data: Dict[str, Any],
|
||||
callback_params: Dict[str, str],
|
||||
stream_id: str,
|
||||
session_id: str,
|
||||
):
|
||||
"""将消息放入队列进行异步处理"""
|
||||
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
|
||||
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
message_payload = {
|
||||
"message_data": message_data,
|
||||
"callback_params": callback_params,
|
||||
"session_id": session_id,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
await input_queue.put(message_payload)
|
||||
logger.debug(f"[WecomAI] 消息已入队: {stream_id}")
|
||||
|
||||
async def convert_message(self, payload: dict) -> AstrBotMessage:
|
||||
"""转换队列中的消息数据为AstrBotMessage,类似webchat的convert_message"""
|
||||
message_data = payload["message_data"]
|
||||
session_id = payload["session_id"]
|
||||
# callback_params = payload["callback_params"] # 保留但暂时不使用
|
||||
|
||||
# 解析消息内容
|
||||
msgtype = message_data.get("msgtype")
|
||||
content = ""
|
||||
image_base64 = []
|
||||
|
||||
_img_url_to_process = []
|
||||
msg_items = []
|
||||
|
||||
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||
content = WecomAIBotMessageParser.parse_text_message(message_data)
|
||||
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||
_img_url_to_process.append(
|
||||
WecomAIBotMessageParser.parse_image_message(message_data)
|
||||
)
|
||||
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
|
||||
# 提取混合消息中的文本内容
|
||||
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
|
||||
text_parts = []
|
||||
for item in msg_items or []:
|
||||
if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||
text_content = item.get("text", {}).get("content", "")
|
||||
if text_content:
|
||||
text_parts.append(text_content)
|
||||
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||
image_url = item.get("image", {}).get("url", "")
|
||||
if image_url:
|
||||
_img_url_to_process.append(image_url)
|
||||
content = " ".join(text_parts) if text_parts else ""
|
||||
else:
|
||||
content = f"[{msgtype}消息]"
|
||||
|
||||
# 并行处理图片下载和解密
|
||||
if _img_url_to_process:
|
||||
tasks = [
|
||||
process_encrypted_image(url, self.encoding_aes_key)
|
||||
for url in _img_url_to_process
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
for success, result in results:
|
||||
if success:
|
||||
image_base64.append(result)
|
||||
else:
|
||||
logger.error(f"处理加密图片失败: {result}")
|
||||
|
||||
# 构建 AstrBotMessage
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_name
|
||||
abm.message_str = content or "[未知消息]"
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.timestamp = int(time.time())
|
||||
abm.raw_message = payload
|
||||
|
||||
# 发送者信息
|
||||
abm.sender = MessageMember(
|
||||
user_id=message_data.get("from", {}).get("userid", "unknown"),
|
||||
nickname=message_data.get("from", {}).get("userid", "unknown"),
|
||||
)
|
||||
|
||||
# 消息类型
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if message_data.get("chattype") == "group"
|
||||
else MessageType.FRIEND_MESSAGE
|
||||
)
|
||||
abm.session_id = session_id
|
||||
|
||||
# 消息内容
|
||||
abm.message = []
|
||||
|
||||
# 处理 At
|
||||
if self.bot_name and f"@{self.bot_name}" in abm.message_str:
|
||||
abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip()
|
||||
abm.message.append(At(qq=self.bot_name, name=self.bot_name))
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
if image_base64:
|
||||
for img_b64 in image_base64:
|
||||
abm.message.append(Image.fromBase64(img_b64))
|
||||
|
||||
logger.debug(f"WecomAIAdapter: {abm.message}")
|
||||
return abm
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
# 企业微信智能机器人主要通过回调响应,这里记录日志
|
||||
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
||||
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
|
||||
|
||||
async def run_both():
|
||||
# 同时运行HTTP服务器和队列监听器
|
||||
await asyncio.gather(
|
||||
self.server.start_server(),
|
||||
self.queue_listener.run(),
|
||||
)
|
||||
|
||||
return run_both()
|
||||
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("企业微信智能机器人适配器正在关闭...")
|
||||
self.shutdown_event.set()
|
||||
await self.server.shutdown()
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""获取平台元数据"""
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
"""处理消息,创建消息事件并提交到事件队列"""
|
||||
try:
|
||||
message_event = WecomAIBotMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
api_client=self.api_client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("处理消息时发生异常: %s", e)
|
||||
|
||||
def get_client(self) -> WecomAIBotAPIClient:
|
||||
"""获取 API 客户端"""
|
||||
return self.api_client
|
||||
|
||||
def get_server(self) -> WecomAIBotServer:
|
||||
"""获取 HTTP 服务器实例"""
|
||||
return self.server
|
||||
378
astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py
Normal file
378
astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
企业微信智能机器人 API 客户端
|
||||
处理消息加密解密、API 调用等
|
||||
"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import Dict, Any, Optional, Tuple, Union
|
||||
from Crypto.Cipher import AES
|
||||
import aiohttp
|
||||
|
||||
from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class WecomAIBotAPIClient:
|
||||
"""企业微信智能机器人 API 客户端"""
|
||||
|
||||
def __init__(self, token: str, encoding_aes_key: str):
|
||||
"""初始化 API 客户端
|
||||
|
||||
Args:
|
||||
token: 企业微信机器人 Token
|
||||
encoding_aes_key: 消息加密密钥
|
||||
"""
|
||||
self.token = token
|
||||
self.encoding_aes_key = encoding_aes_key
|
||||
self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串
|
||||
|
||||
async def decrypt_message(
|
||||
self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str
|
||||
) -> Tuple[int, Optional[Dict[str, Any]]]:
|
||||
"""解密企业微信消息
|
||||
|
||||
Args:
|
||||
encrypted_data: 加密的消息数据
|
||||
msg_signature: 消息签名
|
||||
timestamp: 时间戳
|
||||
nonce: 随机数
|
||||
|
||||
Returns:
|
||||
(错误码, 解密后的消息数据字典)
|
||||
"""
|
||||
try:
|
||||
ret, decrypted_msg = self.wxcpt.DecryptMsg(
|
||||
encrypted_data, msg_signature, timestamp, nonce
|
||||
)
|
||||
|
||||
if ret != WecomAIBotConstants.SUCCESS:
|
||||
logger.error(f"消息解密失败,错误码: {ret}")
|
||||
return ret, None
|
||||
|
||||
# 解析 JSON
|
||||
if decrypted_msg:
|
||||
try:
|
||||
message_data = json.loads(decrypted_msg)
|
||||
logger.debug(f"解密成功,消息内容: {message_data}")
|
||||
return WecomAIBotConstants.SUCCESS, message_data
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}")
|
||||
return WecomAIBotConstants.PARSE_XML_ERROR, None
|
||||
else:
|
||||
logger.error("解密消息为空")
|
||||
return WecomAIBotConstants.DECRYPT_ERROR, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解密过程发生异常: {e}")
|
||||
return WecomAIBotConstants.DECRYPT_ERROR, None
|
||||
|
||||
async def encrypt_message(
|
||||
self, plain_message: str, nonce: str, timestamp: str
|
||||
) -> Optional[str]:
|
||||
"""加密消息
|
||||
|
||||
Args:
|
||||
plain_message: 明文消息
|
||||
nonce: 随机数
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
加密后的消息,失败时返回 None
|
||||
"""
|
||||
try:
|
||||
ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp)
|
||||
|
||||
if ret != WecomAIBotConstants.SUCCESS:
|
||||
logger.error(f"消息加密失败,错误码: {ret}")
|
||||
return None
|
||||
|
||||
logger.debug("消息加密成功")
|
||||
return encrypted_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加密过程发生异常: {e}")
|
||||
return None
|
||||
|
||||
def verify_url(
|
||||
self, msg_signature: str, timestamp: str, nonce: str, echostr: str
|
||||
) -> str:
|
||||
"""验证回调 URL
|
||||
|
||||
Args:
|
||||
msg_signature: 消息签名
|
||||
timestamp: 时间戳
|
||||
nonce: 随机数
|
||||
echostr: 验证字符串
|
||||
|
||||
Returns:
|
||||
验证结果字符串
|
||||
"""
|
||||
try:
|
||||
ret, echo_result = self.wxcpt.VerifyURL(
|
||||
msg_signature, timestamp, nonce, echostr
|
||||
)
|
||||
|
||||
if ret != WecomAIBotConstants.SUCCESS:
|
||||
logger.error(f"URL 验证失败,错误码: {ret}")
|
||||
return "verify fail"
|
||||
|
||||
logger.info("URL 验证成功")
|
||||
return echo_result if echo_result else "verify fail"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"URL 验证发生异常: {e}")
|
||||
return "verify fail"
|
||||
|
||||
async def process_encrypted_image(
|
||||
self, image_url: str, aes_key_base64: Optional[str] = None
|
||||
) -> Tuple[bool, Union[bytes, str]]:
|
||||
"""下载并解密加密图片
|
||||
|
||||
Args:
|
||||
image_url: 加密图片的 URL
|
||||
aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥
|
||||
|
||||
Returns:
|
||||
(是否成功, 图片数据或错误信息)
|
||||
"""
|
||||
try:
|
||||
# 下载图片
|
||||
logger.info(f"开始下载加密图片: {image_url}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=15) as response:
|
||||
if response.status != 200:
|
||||
error_msg = f"图片下载失败,状态码: {response.status}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
encrypted_data = await response.read()
|
||||
logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节")
|
||||
|
||||
# 准备解密密钥
|
||||
if aes_key_base64 is None:
|
||||
aes_key_base64 = self.encoding_aes_key
|
||||
|
||||
if not aes_key_base64:
|
||||
raise ValueError("AES 密钥不能为空")
|
||||
|
||||
# Base64 解码密钥
|
||||
aes_key = base64.b64decode(
|
||||
aes_key_base64 + "=" * (-len(aes_key_base64) % 4)
|
||||
)
|
||||
if len(aes_key) != 32:
|
||||
raise ValueError("无效的 AES 密钥长度: 应为 32 字节")
|
||||
|
||||
iv = aes_key[:16] # 初始向量为密钥前 16 字节
|
||||
|
||||
# 解密图片数据
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted_data = cipher.decrypt(encrypted_data)
|
||||
|
||||
# 去除 PKCS#7 填充
|
||||
pad_len = decrypted_data[-1]
|
||||
if pad_len > 32: # AES-256 块大小为 32 字节
|
||||
raise ValueError("无效的填充长度 (大于32字节)")
|
||||
|
||||
decrypted_data = decrypted_data[:-pad_len]
|
||||
logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节")
|
||||
|
||||
return True, decrypted_data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
error_msg = f"图片下载失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = f"参数错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片处理异常: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
class WecomAIBotStreamMessageBuilder:
|
||||
"""企业微信智能机器人流消息构建器"""
|
||||
|
||||
@staticmethod
|
||||
def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str:
|
||||
"""构建文本流消息
|
||||
|
||||
Args:
|
||||
stream_id: 流 ID
|
||||
content: 文本内容
|
||||
finish: 是否结束
|
||||
|
||||
Returns:
|
||||
JSON 格式的流消息字符串
|
||||
"""
|
||||
plain = {
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||
"stream": {"id": stream_id, "finish": finish, "content": content},
|
||||
}
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def make_image_stream(
|
||||
stream_id: str, image_data: bytes, finish: bool = False
|
||||
) -> str:
|
||||
"""构建图片流消息
|
||||
|
||||
Args:
|
||||
stream_id: 流 ID
|
||||
image_data: 图片二进制数据
|
||||
finish: 是否结束
|
||||
|
||||
Returns:
|
||||
JSON 格式的流消息字符串
|
||||
"""
|
||||
image_md5 = hashlib.md5(image_data).hexdigest()
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
plain = {
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||
"stream": {
|
||||
"id": stream_id,
|
||||
"finish": finish,
|
||||
"msg_item": [
|
||||
{
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
|
||||
"image": {"base64": image_base64, "md5": image_md5},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def make_mixed_stream(
|
||||
stream_id: str, content: str, msg_items: list, finish: bool = False
|
||||
) -> str:
|
||||
"""构建混合类型流消息
|
||||
|
||||
Args:
|
||||
stream_id: 流 ID
|
||||
content: 文本内容
|
||||
msg_items: 消息项列表
|
||||
finish: 是否结束
|
||||
|
||||
Returns:
|
||||
JSON 格式的流消息字符串
|
||||
"""
|
||||
plain = {
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||
"stream": {"id": stream_id, "finish": finish, "msg_item": msg_items},
|
||||
}
|
||||
if content:
|
||||
plain["stream"]["content"] = content
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def make_text(content: str) -> str:
|
||||
"""构建文本消息
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
|
||||
Returns:
|
||||
JSON 格式的文本消息字符串
|
||||
"""
|
||||
plain = {"msgtype": "text", "text": {"content": content}}
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
|
||||
class WecomAIBotMessageParser:
|
||||
"""企业微信智能机器人消息解析器"""
|
||||
|
||||
@staticmethod
|
||||
def parse_text_message(data: Dict[str, Any]) -> Optional[str]:
|
||||
"""解析文本消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
文本内容,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("text", {}).get("content")
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("文本消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_image_message(data: Dict[str, Any]) -> Optional[str]:
|
||||
"""解析图片消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
图片 URL,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("image", {}).get("url")
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("图片消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""解析流消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
流消息数据,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
stream_data = data.get("stream", {})
|
||||
return {
|
||||
"id": stream_data.get("id"),
|
||||
"finish": stream_data.get("finish"),
|
||||
"content": stream_data.get("content"),
|
||||
"msg_item": stream_data.get("msg_item", []),
|
||||
}
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("流消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]:
|
||||
"""解析混合消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
消息项列表,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("mixed", {}).get("msg_item", [])
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("混合消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""解析事件消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
事件数据,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("event", {})
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("事件消息解析失败")
|
||||
return None
|
||||
149
astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py
Normal file
149
astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
企业微信智能机器人事件处理模块,处理消息事件的发送和接收
|
||||
"""
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Plain,
|
||||
)
|
||||
from astrbot.api import logger
|
||||
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_queue_mgr import wecomai_queue_mgr
|
||||
|
||||
|
||||
class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
"""企业微信智能机器人消息事件"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj,
|
||||
platform_meta,
|
||||
session_id: str,
|
||||
api_client: WecomAIBotAPIClient,
|
||||
):
|
||||
"""初始化消息事件
|
||||
|
||||
Args:
|
||||
message_str: 消息字符串
|
||||
message_obj: 消息对象
|
||||
platform_meta: 平台元数据
|
||||
session_id: 会话 ID
|
||||
api_client: API 客户端
|
||||
"""
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.api_client = api_client
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
message_chain: MessageChain,
|
||||
stream_id: str,
|
||||
streaming: bool = False,
|
||||
):
|
||||
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
|
||||
if not message_chain:
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
}
|
||||
)
|
||||
return ""
|
||||
|
||||
data = ""
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
data = comp.text
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"session_id": stream_id,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# 处理图片消息
|
||||
try:
|
||||
image_base64 = await comp.convert_to_base64()
|
||||
if image_base64:
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"image_data": image_base64,
|
||||
"streaming": streaming,
|
||||
"session_id": stream_id,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning("图片数据为空,跳过")
|
||||
except Exception as e:
|
||||
logger.error("处理图片消息失败: %s", e)
|
||||
else:
|
||||
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息"""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
"wecom_ai_bot platform event raw_message should be a dict"
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback=False):
|
||||
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||
final_data = ""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
"wecom_ai_bot platform event raw_message should be a dict"
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
|
||||
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
|
||||
increment_plain = ""
|
||||
async for chain in generator:
|
||||
# 累积增量内容,并改写 Plain 段
|
||||
chain.squash_plain()
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
comp.text = increment_plain + comp.text
|
||||
increment_plain = comp.text
|
||||
break
|
||||
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"session_id": self.session_id,
|
||||
}
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
|
||||
final_data += await WecomAIBotMessageEvent._send(
|
||||
chain,
|
||||
stream_id=stream_id,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "complete", # complete means we return the final result
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"session_id": self.session_id,
|
||||
}
|
||||
)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
148
astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py
Normal file
148
astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
企业微信智能机器人队列管理器
|
||||
参考 webchat_queue_mgr.py,为企业微信智能机器人实现队列机制
|
||||
支持异步消息处理和流式响应
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class WecomAIQueueMgr:
|
||||
"""企业微信智能机器人队列管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.queues: Dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
|
||||
|
||||
self.back_queues: Dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输出队列的映射 - 用于发送机器人响应"""
|
||||
|
||||
self.pending_responses: Dict[str, Dict[str, Any]] = {}
|
||||
"""待处理的响应缓存,用于流式响应"""
|
||||
|
||||
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输入队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
输入队列实例
|
||||
"""
|
||||
if session_id not in self.queues:
|
||||
self.queues[session_id] = asyncio.Queue()
|
||||
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
|
||||
return self.queues[session_id]
|
||||
|
||||
def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输出队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
输出队列实例
|
||||
"""
|
||||
if session_id not in self.back_queues:
|
||||
self.back_queues[session_id] = asyncio.Queue()
|
||||
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
|
||||
return self.back_queues[session_id]
|
||||
|
||||
def remove_queues(self, session_id: str):
|
||||
"""移除指定会话的所有队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
|
||||
if session_id in self.back_queues:
|
||||
del self.back_queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输出队列: {session_id}")
|
||||
|
||||
if session_id in self.pending_responses:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||
|
||||
def has_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
是否存在队列
|
||||
"""
|
||||
return session_id in self.queues
|
||||
|
||||
def has_back_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的输出队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
是否存在输出队列
|
||||
"""
|
||||
return session_id in self.back_queues
|
||||
|
||||
def set_pending_response(self, session_id: str, callback_params: Dict[str, str]):
|
||||
"""设置待处理的响应参数
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
callback_params: 回调参数(nonce, timestamp等)
|
||||
"""
|
||||
self.pending_responses[session_id] = {
|
||||
"callback_params": callback_params,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||
|
||||
def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取待处理的响应参数
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
响应参数,如果不存在则返回None
|
||||
"""
|
||||
return self.pending_responses.get(session_id)
|
||||
|
||||
def cleanup_expired_responses(self, max_age_seconds: int = 300):
|
||||
"""清理过期的待处理响应
|
||||
|
||||
Args:
|
||||
max_age_seconds: 最大存活时间(秒)
|
||||
"""
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
expired_sessions = []
|
||||
|
||||
for session_id, response_data in self.pending_responses.items():
|
||||
if current_time - response_data["timestamp"] > max_age_seconds:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取队列统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return {
|
||||
"input_queues": len(self.queues),
|
||||
"output_queues": len(self.back_queues),
|
||||
"pending_responses": len(self.pending_responses),
|
||||
}
|
||||
|
||||
|
||||
# 全局队列管理器实例
|
||||
wecomai_queue_mgr = WecomAIQueueMgr()
|
||||
166
astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py
Normal file
166
astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
企业微信智能机器人 HTTP 服务器
|
||||
处理企业微信智能机器人的 HTTP 回调请求
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Callable
|
||||
|
||||
import quart
|
||||
from astrbot.api import logger
|
||||
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
|
||||
|
||||
class WecomAIBotServer:
|
||||
"""企业微信智能机器人 HTTP 服务器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
api_client: WecomAIBotAPIClient,
|
||||
message_handler: Optional[
|
||||
Callable[[Dict[str, Any], Dict[str, str]], Any]
|
||||
] = None,
|
||||
):
|
||||
"""初始化服务器
|
||||
|
||||
Args:
|
||||
host: 监听地址
|
||||
port: 监听端口
|
||||
api_client: API客户端实例
|
||||
message_handler: 消息处理回调函数
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.api_client = api_client
|
||||
self.message_handler = message_handler
|
||||
|
||||
self.app = quart.Quart(__name__)
|
||||
self._setup_routes()
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置 Quart 路由"""
|
||||
|
||||
# 使用 Quart 的 add_url_rule 方法添加路由
|
||||
self.app.add_url_rule(
|
||||
"/webhook/wecom-ai-bot",
|
||||
view_func=self.verify_url,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/webhook/wecom-ai-bot",
|
||||
view_func=self.handle_message,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def verify_url(self):
|
||||
"""验证回调 URL"""
|
||||
args = quart.request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
echostr = args.get("echostr")
|
||||
|
||||
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||
logger.error("URL 验证参数缺失")
|
||||
return "verify fail", 400
|
||||
|
||||
# 类型检查确保不为 None
|
||||
assert msg_signature is not None
|
||||
assert timestamp is not None
|
||||
assert nonce is not None
|
||||
assert echostr is not None
|
||||
|
||||
logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
|
||||
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
|
||||
return result, 200, {"Content-Type": "text/plain"}
|
||||
|
||||
async def handle_message(self):
|
||||
"""处理消息回调"""
|
||||
args = quart.request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
|
||||
if not all([msg_signature, timestamp, nonce]):
|
||||
logger.error("消息回调参数缺失")
|
||||
return "缺少必要参数", 400
|
||||
|
||||
# 类型检查确保不为 None
|
||||
assert msg_signature is not None
|
||||
assert timestamp is not None
|
||||
assert nonce is not None
|
||||
|
||||
logger.debug(
|
||||
f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取请求体
|
||||
post_data = await quart.request.get_data()
|
||||
|
||||
# 确保 post_data 是 bytes 类型
|
||||
if isinstance(post_data, str):
|
||||
post_data = post_data.encode("utf-8")
|
||||
|
||||
# 解密消息
|
||||
ret_code, message_data = await self.api_client.decrypt_message(
|
||||
post_data, msg_signature, timestamp, nonce
|
||||
)
|
||||
|
||||
if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
|
||||
logger.error("消息解密失败,错误码: %d", ret_code)
|
||||
return "消息解密失败", 400
|
||||
|
||||
# 调用消息处理器
|
||||
response = None
|
||||
if self.message_handler:
|
||||
try:
|
||||
response = await self.message_handler(
|
||||
message_data, {"nonce": nonce, "timestamp": timestamp}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("消息处理器执行异常: %s", e)
|
||||
return "消息处理异常", 500
|
||||
|
||||
if response:
|
||||
return response, 200, {"Content-Type": "text/plain"}
|
||||
else:
|
||||
return "success", 200, {"Content-Type": "text/plain"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("处理消息时发生异常: %s", e)
|
||||
return "内部服务器错误", 500
|
||||
|
||||
async def start_server(self):
|
||||
"""启动服务器"""
|
||||
logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
|
||||
|
||||
try:
|
||||
await self.app.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("服务器运行异常: %s", e)
|
||||
raise
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
"""关闭触发器"""
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
logger.info("企业微信智能机器人服务器正在关闭...")
|
||||
self.shutdown_event.set()
|
||||
|
||||
def get_app(self):
|
||||
"""获取 Quart 应用实例"""
|
||||
return self.app
|
||||
199
astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py
Normal file
199
astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
企业微信智能机器人工具模块
|
||||
提供常量定义、工具函数和辅助方法
|
||||
"""
|
||||
|
||||
import string
|
||||
import random
|
||||
import hashlib
|
||||
import base64
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from Crypto.Cipher import AES
|
||||
from typing import Any, Tuple
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
# 常量定义
|
||||
class WecomAIBotConstants:
|
||||
"""企业微信智能机器人常量"""
|
||||
|
||||
# 消息类型
|
||||
MSG_TYPE_TEXT = "text"
|
||||
MSG_TYPE_IMAGE = "image"
|
||||
MSG_TYPE_MIXED = "mixed"
|
||||
MSG_TYPE_STREAM = "stream"
|
||||
MSG_TYPE_EVENT = "event"
|
||||
|
||||
# 流消息状态
|
||||
STREAM_CONTINUE = False
|
||||
STREAM_FINISH = True
|
||||
|
||||
# 错误码
|
||||
SUCCESS = 0
|
||||
DECRYPT_ERROR = -40001
|
||||
VALIDATE_SIGNATURE_ERROR = -40002
|
||||
PARSE_XML_ERROR = -40003
|
||||
COMPUTE_SIGNATURE_ERROR = -40004
|
||||
ILLEGAL_AES_KEY = -40005
|
||||
VALIDATE_APPID_ERROR = -40006
|
||||
ENCRYPT_AES_ERROR = -40007
|
||||
ILLEGAL_BUFFER = -40008
|
||||
|
||||
|
||||
def generate_random_string(length: int = 10) -> str:
|
||||
"""生成随机字符串
|
||||
|
||||
Args:
|
||||
length: 字符串长度,默认为 10
|
||||
|
||||
Returns:
|
||||
随机字符串
|
||||
"""
|
||||
letters = string.ascii_letters + string.digits
|
||||
return "".join(random.choice(letters) for _ in range(length))
|
||||
|
||||
|
||||
def calculate_image_md5(image_data: bytes) -> str:
|
||||
"""计算图片数据的 MD5 值
|
||||
|
||||
Args:
|
||||
image_data: 图片二进制数据
|
||||
|
||||
Returns:
|
||||
MD5 哈希值(十六进制字符串)
|
||||
"""
|
||||
return hashlib.md5(image_data).hexdigest()
|
||||
|
||||
|
||||
def encode_image_base64(image_data: bytes) -> str:
|
||||
"""将图片数据编码为 Base64
|
||||
|
||||
Args:
|
||||
image_data: 图片二进制数据
|
||||
|
||||
Returns:
|
||||
Base64 编码的字符串
|
||||
"""
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
|
||||
def format_session_id(session_type: str, session_id: str) -> str:
|
||||
"""格式化会话 ID
|
||||
|
||||
Args:
|
||||
session_type: 会话类型 ("user", "group")
|
||||
session_id: 原始会话 ID
|
||||
|
||||
Returns:
|
||||
格式化后的会话 ID
|
||||
"""
|
||||
return f"wecom_ai_bot_{session_type}_{session_id}"
|
||||
|
||||
|
||||
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
|
||||
"""解析格式化的会话 ID
|
||||
|
||||
Args:
|
||||
formatted_session_id: 格式化的会话 ID
|
||||
|
||||
Returns:
|
||||
(会话类型, 原始会话ID)
|
||||
"""
|
||||
parts = formatted_session_id.split("_", 3)
|
||||
if (
|
||||
len(parts) >= 4
|
||||
and parts[0] == "wecom"
|
||||
and parts[1] == "ai"
|
||||
and parts[2] == "bot"
|
||||
):
|
||||
return parts[3], "_".join(parts[4:]) if len(parts) > 4 else ""
|
||||
return "user", formatted_session_id
|
||||
|
||||
|
||||
def safe_json_loads(json_str: str, default: Any = None) -> Any:
|
||||
"""安全地解析 JSON 字符串
|
||||
|
||||
Args:
|
||||
json_str: JSON 字符串
|
||||
default: 解析失败时的默认值
|
||||
|
||||
Returns:
|
||||
解析结果或默认值
|
||||
"""
|
||||
import json
|
||||
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}")
|
||||
return default
|
||||
|
||||
|
||||
def format_error_response(error_code: int, error_msg: str) -> str:
|
||||
"""格式化错误响应
|
||||
|
||||
Args:
|
||||
error_code: 错误码
|
||||
error_msg: 错误信息
|
||||
|
||||
Returns:
|
||||
格式化的错误响应字符串
|
||||
"""
|
||||
return f"Error {error_code}: {error_msg}"
|
||||
|
||||
|
||||
async def process_encrypted_image(
|
||||
image_url: str, aes_key_base64: str
|
||||
) -> Tuple[bool, str]:
|
||||
"""下载并解密加密图片
|
||||
|
||||
Args:
|
||||
image_url: 加密图片的URL
|
||||
aes_key_base64: Base64编码的AES密钥(与回调加解密相同)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码,
|
||||
status 为 False 时 data 是错误信息
|
||||
"""
|
||||
# 1. 下载加密图片
|
||||
logger.info("开始下载加密图片: %s", image_url)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=15) as response:
|
||||
response.raise_for_status()
|
||||
encrypted_data = await response.read()
|
||||
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
error_msg = f"下载图片失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 2. 准备AES密钥和IV
|
||||
if not aes_key_base64:
|
||||
raise ValueError("AES密钥不能为空")
|
||||
|
||||
# Base64解码密钥 (自动处理填充)
|
||||
aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
|
||||
if len(aes_key) != 32:
|
||||
raise ValueError("无效的AES密钥长度: 应为32字节")
|
||||
|
||||
iv = aes_key[:16] # 初始向量为密钥前16字节
|
||||
|
||||
# 3. 解密图片数据
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted_data = cipher.decrypt(encrypted_data)
|
||||
|
||||
# 4. 去除PKCS#7填充 (Python 3兼容写法)
|
||||
pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值
|
||||
if pad_len > 32: # AES-256块大小为32字节
|
||||
raise ValueError("无效的填充长度 (大于32字节)")
|
||||
|
||||
decrypted_data = decrypted_data[:-pad_len]
|
||||
logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))
|
||||
|
||||
# 5. 转换为base64编码
|
||||
base64_data = base64.b64encode(decrypted_data).decode("utf-8")
|
||||
logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data))
|
||||
|
||||
return True, base64_data
|
||||
@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
id=self.config.get("id", "weixin_official_account"),
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
ret = {
|
||||
ret: dict[str, str | list[dict]] = {
|
||||
"role": self.role,
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
if self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
tool_calls_dict = [
|
||||
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
|
||||
]
|
||||
ret["tool_calls"] = tool_calls_dict
|
||||
return ret
|
||||
|
||||
|
||||
@@ -117,7 +120,14 @@ class ProviderRequest:
|
||||
"""模型名称,为 None 时使用提供商的默认模型"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||
return (
|
||||
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
|
||||
f"image_count={len(self.image_urls or [])}, "
|
||||
f"func_tool={self.func_tool}, "
|
||||
f"contexts={self._print_friendly_context()}, "
|
||||
f"system_prompt={self.system_prompt}, "
|
||||
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
from typing import Dict, List, Awaitable
|
||||
from typing import Dict, List, Awaitable, Callable, Any
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
|
||||
@@ -109,7 +109,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Awaitable,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -132,7 +132,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Awaitable,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
) -> None:
|
||||
"""添加函数调用工具
|
||||
|
||||
@@ -220,7 +220,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
ready_future: asyncio.Future = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
try:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .provider import (
|
||||
Provider,
|
||||
STTProvider,
|
||||
TTSProvider,
|
||||
EmbeddingProvider,
|
||||
RerankProvider,
|
||||
)
|
||||
from .register import llm_tools, provider_cls_map
|
||||
from ..persona_mgr import PersonaManager
|
||||
|
||||
@@ -22,7 +27,7 @@ class ProviderManager:
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: List = config["provider"]
|
||||
self.providers_config: list = config["provider"]
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
@@ -30,15 +35,20 @@ class ProviderManager:
|
||||
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
||||
self.default_persona_name = persona_mgr.default_persona
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
self.provider_insts: list[Provider] = []
|
||||
"""加载的 Provider 的实例"""
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
self.stt_provider_insts: list[STTProvider] = []
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
self.tts_provider_insts: list[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
self.embedding_provider_insts: list[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map: dict[str, Provider] = {}
|
||||
self.rerank_provider_insts: list[RerankProvider] = []
|
||||
"""加载的 Rerank Provider 的实例"""
|
||||
self.inst_map: dict[
|
||||
str,
|
||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||
] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
@@ -87,19 +97,31 @@ class ProviderManager:
|
||||
)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
|
||||
prov = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
||||
prov, TTSProvider
|
||||
):
|
||||
self.curr_tts_provider_inst = prov
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||||
prov, STTProvider
|
||||
):
|
||||
self.curr_stt_provider_inst = prov
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||||
prov, Provider
|
||||
):
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
||||
def get_using_provider(
|
||||
self, provider_type: ProviderType, umo=None
|
||||
) -> Provider | STTProvider | TTSProvider | None:
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
@@ -152,7 +174,11 @@ class ProviderManager:
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
try:
|
||||
await self.load_provider(provider_config)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
@@ -211,6 +237,8 @@ class ProviderManager:
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "coze":
|
||||
from .sources.coze_source import ProviderCoze as ProviderCoze
|
||||
case "dashscope":
|
||||
from .sources.dashscope_source import (
|
||||
ProviderDashscope as ProviderDashscope,
|
||||
@@ -303,12 +331,14 @@ class ProviderManager:
|
||||
provider_metadata = provider_cls_map[provider_config["type"]]
|
||||
try:
|
||||
# 按任务实例化提供商
|
||||
cls_type = provider_metadata.cls_type
|
||||
if not cls_type:
|
||||
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||||
return
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
@@ -327,9 +357,7 @@ class ProviderManager:
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
@@ -345,7 +373,7 @@ class ProviderManager:
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = provider_metadata.cls_type(
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.selected_default_persona,
|
||||
@@ -366,23 +394,25 @@ class ProviderManager:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type in [
|
||||
ProviderType.EMBEDDING,
|
||||
ProviderType.RERANK,
|
||||
]:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(
|
||||
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
||||
)
|
||||
raise Exception(
|
||||
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
||||
)
|
||||
|
||||
async def reload(self, provider_config: dict):
|
||||
await self.terminate_provider(provider_config["id"])
|
||||
@@ -430,11 +460,17 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
if self.inst_map[provider_id] in self.provider_insts:
|
||||
self.provider_insts.remove(self.inst_map[provider_id])
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if isinstance(prov_inst, Provider):
|
||||
self.provider_insts.remove(prov_inst)
|
||||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||
self.stt_provider_insts.remove(self.inst_map[provider_id])
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if isinstance(prov_inst, STTProvider):
|
||||
self.stt_provider_insts.remove(prov_inst)
|
||||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||
self.tts_provider_insts.remove(self.inst_map[provider_id])
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if isinstance(prov_inst, TTSProvider):
|
||||
self.tts_provider_insts.remove(prov_inst)
|
||||
|
||||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||
self.curr_provider_inst = None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from typing import List
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
@@ -68,14 +69,15 @@ class Provider(AbstractProvider):
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""获得提供商 Key"""
|
||||
return self.provider_config.get("key", [])
|
||||
keys = self.provider_config.get("key", [""])
|
||||
return keys or [""]
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_key(self, key: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_models(self) -> List[str]:
|
||||
async def get_models(self) -> List[str]:
|
||||
"""获得支持的模型列表"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -202,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
async def get_embeddings_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
batch_size: int = 16,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> list[list[float]]:
|
||||
"""批量获取文本的向量,分批处理以节省内存
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 每批处理的文本数量
|
||||
tasks_limit: 并发任务数量限制
|
||||
max_retries: 失败时的最大重试次数
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(tasks_limit)
|
||||
all_embeddings: list[list[float]] = []
|
||||
failed_batches: list[tuple[int, list[str]]] = []
|
||||
completed_count = 0
|
||||
total_count = len(texts)
|
||||
|
||||
async def process_batch(batch_idx: int, batch_texts: list[str]):
|
||||
nonlocal completed_count
|
||||
async with semaphore:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
batch_embeddings = await self.get_embeddings(batch_texts)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
completed_count += len(batch_texts)
|
||||
if progress_callback:
|
||||
await progress_callback(completed_count, total_count)
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt == max_retries - 1:
|
||||
# 最后一次重试失败,记录失败的批次
|
||||
failed_batches.append((batch_idx, batch_texts))
|
||||
raise Exception(
|
||||
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}"
|
||||
)
|
||||
# 等待一段时间后重试,使用指数退避
|
||||
await asyncio.sleep(2**attempt)
|
||||
|
||||
tasks = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i : i + batch_size]
|
||||
batch_idx = i // batch_size
|
||||
tasks.append(process_batch(batch_idx, batch_texts))
|
||||
|
||||
# 收集所有任务的结果,包括失败的任务
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 检查是否有失败的任务
|
||||
errors = [r for r in results if isinstance(r, Exception)]
|
||||
if errors:
|
||||
error_msg = (
|
||||
f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
|
||||
)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
|
||||
class RerankProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
|
||||
@@ -10,7 +10,7 @@ from anthropic.types import Message
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from typing import AsyncGenerator
|
||||
@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
|
||||
)
|
||||
|
||||
self.chosen_api_key: str = ""
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.api_keys: List = super().get_keys()
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
@@ -70,9 +70,13 @@ class ProviderAnthropic(Provider):
|
||||
{
|
||||
"type": "tool_use",
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], str)
|
||||
else tool_call["function"]["arguments"],
|
||||
"input": (
|
||||
json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(
|
||||
tool_call["function"]["arguments"], str
|
||||
)
|
||||
else tool_call["function"]["arguments"]
|
||||
),
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
)
|
||||
@@ -100,7 +104,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
@@ -131,7 +135,7 @@ class ProviderAnthropic(Provider):
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
self, payloads: dict, tools: ToolSet | None
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
@@ -322,7 +326,7 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
@@ -355,9 +359,11 @@ class ProviderAnthropic(Provider):
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
314
astrbot/core/provider/sources/coze_api_client.py
Normal file
314
astrbot/core/provider/sources/coze_api_client.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import io
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class CozeAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = None
|
||||
|
||||
async def _ensure_session(self):
|
||||
"""确保HTTP session存在"""
|
||||
if self.session is None:
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=False if self.api_base.startswith("http://") else True,
|
||||
limit=100,
|
||||
limit_per_host=30,
|
||||
keepalive_timeout=30,
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=120, # 默认超时时间
|
||||
connect=30,
|
||||
sock_read=120,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "text/event-stream",
|
||||
}
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=headers, timeout=timeout, connector=connector
|
||||
)
|
||||
return self.session
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id
|
||||
|
||||
Args:
|
||||
file_data (bytes): 文件的二进制数据
|
||||
Returns:
|
||||
str: 上传成功后返回的 file_id
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v1/files/upload"
|
||||
|
||||
try:
|
||||
file_io = io.BytesIO(file_data)
|
||||
async with session.post(
|
||||
url,
|
||||
data={
|
||||
"file": file_io,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(
|
||||
f"文件上传响应状态: {response.status}, 内容: {response_text}"
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await response.json()
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"文件上传响应解析失败: {response_text}")
|
||||
|
||||
if result.get("code") != 0:
|
||||
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
||||
|
||||
file_id = result["data"]["id"]
|
||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("文件上传超时")
|
||||
raise Exception("文件上传超时")
|
||||
except Exception as e:
|
||||
logger.error(f"文件上传失败: {str(e)}")
|
||||
raise Exception(f"文件上传失败: {str(e)}")
|
||||
|
||||
async def download_image(self, image_url: str) -> bytes:
|
||||
"""下载图片并返回字节数据
|
||||
|
||||
Args:
|
||||
image_url (str): 图片的URL
|
||||
Returns:
|
||||
bytes: 图片的二进制数据
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
|
||||
try:
|
||||
async with session.get(image_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"下载图片失败,状态码: {response.status}")
|
||||
|
||||
image_data = await response.read()
|
||||
return image_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败 {image_url}: {str(e)}")
|
||||
raise Exception(f"下载图片失败: {str(e)}")
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
bot_id: str,
|
||||
user_id: str,
|
||||
additional_messages: List[Dict] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
auto_save_history: bool = True,
|
||||
stream: bool = True,
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""发送聊天消息并返回流式响应
|
||||
|
||||
Args:
|
||||
bot_id: Bot ID
|
||||
user_id: 用户ID
|
||||
additional_messages: 额外消息列表
|
||||
conversation_id: 会话ID
|
||||
auto_save_history: 是否自动保存历史
|
||||
stream: 是否流式响应
|
||||
timeout: 超时时间
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/chat"
|
||||
|
||||
payload = {
|
||||
"bot_id": bot_id,
|
||||
"user_id": user_id,
|
||||
"stream": stream,
|
||||
"auto_save_history": auto_save_history,
|
||||
}
|
||||
|
||||
if additional_messages:
|
||||
payload["additional_messages"] = additional_messages
|
||||
|
||||
params = {}
|
||||
if conversation_id:
|
||||
params["conversation_id"] = conversation_id
|
||||
|
||||
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||
|
||||
# SSE
|
||||
buffer = ""
|
||||
event_type = None
|
||||
event_data = None
|
||||
|
||||
async for chunk in response.content:
|
||||
if chunk:
|
||||
buffer += chunk.decode("utf-8", errors="ignore")
|
||||
lines = buffer.split("\n")
|
||||
buffer = lines[-1]
|
||||
|
||||
for line in lines[:-1]:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
if event_type and event_data:
|
||||
yield {"event": event_type, "data": event_data}
|
||||
event_type = None
|
||||
event_data = None
|
||||
elif line.startswith("event:"):
|
||||
event_type = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if data_str and data_str != "[DONE]":
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
event_data = {"content": data_str}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
||||
except Exception as e:
|
||||
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
||||
|
||||
async def clear_context(self, conversation_id: str):
|
||||
"""清空会话上下文
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
Returns:
|
||||
dict: API响应结果
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/conversation/message/clear_context"
|
||||
payload = {"conversation_id": conversation_id}
|
||||
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
||||
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
raise Exception("Coze API 返回非JSON格式")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception("Coze API 请求超时")
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"Coze API 请求失败: {str(e)}")
|
||||
|
||||
async def get_message_list(
|
||||
self,
|
||||
conversation_id: str,
|
||||
order: str = "desc",
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""获取消息列表
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
order: 排序方式 (asc/desc)
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
Returns:
|
||||
dict: API响应结果
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/conversation/message/list"
|
||||
params = {
|
||||
"conversation_id": conversation_id,
|
||||
"order": order,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Coze消息列表失败: {str(e)}")
|
||||
raise Exception(f"获取Coze消息列表失败: {str(e)}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭会话"""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
async def test_coze_api_client():
|
||||
api_key = os.getenv("COZE_API_KEY", "")
|
||||
bot_id = os.getenv("COZE_BOT_ID", "")
|
||||
client = CozeAPIClient(api_key=api_key)
|
||||
|
||||
try:
|
||||
with open("README.md", "rb") as f:
|
||||
file_data = f.read()
|
||||
file_id = await client.upload_file(file_data)
|
||||
print(f"Uploaded file_id: {file_id}")
|
||||
async for event in client.chat_messages(
|
||||
bot_id=bot_id,
|
||||
user_id="test_user",
|
||||
additional_messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
[
|
||||
{"type": "text", "text": "这是什么"},
|
||||
{"type": "file", "file_id": file_id},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
"content_type": "object_string",
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
print(f"Event: {event}")
|
||||
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(test_coze_api_client())
|
||||
635
astrbot/core/provider/sources/coze_source.py
Normal file
635
astrbot/core/provider/sources/coze_source.py
Normal file
@@ -0,0 +1,635 @@
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import AsyncGenerator, Dict
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
|
||||
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
||||
class ProviderCoze(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://")
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
self.conversation_ids: Dict[str, str] = {}
|
||||
self.file_id_cache: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
||||
"""生成统一的缓存键
|
||||
|
||||
Args:
|
||||
data: 图片数据或路径
|
||||
is_base64: 是否是 base64 数据
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
|
||||
try:
|
||||
if is_base64 and data.startswith("data:image/"):
|
||||
try:
|
||||
header, encoded = data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
cache_key = hashlib.md5(image_bytes).hexdigest()
|
||||
return cache_key
|
||||
except Exception:
|
||||
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
if data.startswith(("http://", "https://")):
|
||||
# URL图片,使用URL作为缓存键
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
clean_path = (
|
||||
data.split("_")[0]
|
||||
if "_" in data and len(data.split("_")) >= 3
|
||||
else data
|
||||
)
|
||||
|
||||
if os.path.exists(clean_path):
|
||||
with open(clean_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
cache_key = hashlib.md5(file_content).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
||||
return cache_key
|
||||
|
||||
async def _upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
session_id: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id"""
|
||||
# 使用 API 客户端上传文件
|
||||
file_id = await self.api_client.upload_file(file_data)
|
||||
|
||||
# 缓存 file_id
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self, image_url: str, session_id: str | None = None
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
# 计算哈希实现缓存
|
||||
cache_key = self._generate_cache_key(image_url) if session_id else None
|
||||
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
|
||||
file_id = await self._upload_file(image_data, session_id, cache_key)
|
||||
|
||||
if session_id and cache_key:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {str(e)}")
|
||||
raise Exception(f"处理图片失败: {str(e)}")
|
||||
|
||||
async def _process_context_images(
|
||||
self, content: str | list, session_id: str
|
||||
) -> str:
|
||||
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
processed_content = []
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
processed_content.append(item)
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片逻辑
|
||||
if "file_id" in item:
|
||||
# 已经有 file_id
|
||||
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
||||
processed_content.append(item)
|
||||
else:
|
||||
# 获取图片数据
|
||||
image_data = ""
|
||||
if "image_url" in item and isinstance(item["image_url"], dict):
|
||||
image_data = item["image_url"].get("url", "")
|
||||
elif "data" in item:
|
||||
image_data = item.get("data", "")
|
||||
elif "url" in item:
|
||||
image_data = item.get("url", "")
|
||||
|
||||
if not image_data:
|
||||
continue
|
||||
# 计算哈希用于缓存
|
||||
cache_key = self._generate_cache_key(
|
||||
image_data, is_base64=image_data.startswith("data:image/")
|
||||
)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id}
|
||||
)
|
||||
else:
|
||||
# 上传图片并缓存
|
||||
if image_data.startswith("data:image/"):
|
||||
# base64 处理
|
||||
_, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
elif image_data.startswith(("http://", "https://")):
|
||||
# URL 图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
image_data, session_id
|
||||
)
|
||||
# 为URL图片也添加缓存
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
elif os.path.exists(image_data):
|
||||
# 本地文件
|
||||
with open(image_data, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法处理的图片格式: {image_data[:50]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id}
|
||||
)
|
||||
|
||||
result = json.dumps(processed_content, ensure_ascii=False)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"处理上下文图片失败: {str(e)}")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
else:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""文本对话, 内部使用流式接口实现非流式
|
||||
|
||||
Args:
|
||||
prompt (str): 用户提示词
|
||||
session_id (str): 会话ID
|
||||
image_urls (List[str]): 图片URL列表
|
||||
func_tool (FuncCall): 函数调用工具(不支持)
|
||||
contexts (List): 上下文列表
|
||||
system_prompt (str): 系统提示语
|
||||
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
||||
model (str): 模型名称(不支持)
|
||||
Returns:
|
||||
LLMResponse: LLM响应对象
|
||||
"""
|
||||
accumulated_content = ""
|
||||
final_response = None
|
||||
|
||||
async for llm_response in self.text_chat_stream(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
model=model,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.completion_text:
|
||||
accumulated_content += llm_response.completion_text
|
||||
else:
|
||||
final_response = llm_response
|
||||
|
||||
if final_response:
|
||||
return final_response
|
||||
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
else:
|
||||
return LLMResponse(role="assistant", completion_text="")
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话接口"""
|
||||
# 用户ID参数(参考文档, 可以自定义)
|
||||
user_id = session_id or kwargs.get("user", "default_user")
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{"role": "system", "content": system_prompt, "content_type": "text"}
|
||||
)
|
||||
|
||||
if not self.auto_save_history and contexts:
|
||||
# 如果关闭了自动保存历史,传入上下文
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
content = ctx["content"]
|
||||
content_type = ctx.get("content_type", "text")
|
||||
|
||||
# 处理可能包含图片的上下文
|
||||
if (
|
||||
content_type == "object_string"
|
||||
or (isinstance(content, str) and content.startswith("["))
|
||||
or (
|
||||
isinstance(content, list)
|
||||
and any(
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "image_url"
|
||||
for item in content
|
||||
)
|
||||
)
|
||||
):
|
||||
processed_content = await self._process_context_images(
|
||||
content, user_id
|
||||
)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": (
|
||||
content
|
||||
if isinstance(content, str)
|
||||
else json.dumps(content, ensure_ascii=False)
|
||||
),
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
||||
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
try:
|
||||
if url.startswith(("http://", "https://")):
|
||||
# 网络图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
url, user_id
|
||||
)
|
||||
else:
|
||||
# 本地文件或 base64
|
||||
if url.startswith("data:image/"):
|
||||
# base64
|
||||
_, encoded = url.split(",", 1)
|
||||
image_data = base64.b64decode(encoded)
|
||||
cache_key = self._generate_cache_key(
|
||||
url, is_base64=True
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data, user_id, cache_key
|
||||
)
|
||||
else:
|
||||
# 本地文件
|
||||
if os.path.exists(url):
|
||||
with open(url, "rb") as f:
|
||||
image_data = f.read()
|
||||
# 用文件路径和修改时间来缓存
|
||||
file_stat = os.stat(url)
|
||||
cache_key = self._generate_cache_key(
|
||||
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
||||
is_base64=False,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data, user_id, cache_key
|
||||
)
|
||||
else:
|
||||
logger.warning(f"图片文件不存在: {url}")
|
||||
continue
|
||||
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {url}: {str(e)}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
if prompt:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
self.conversation_ids[user_id] = data["conversation_id"]
|
||||
|
||||
elif event_type == "conversation.message.delta":
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
message_started = True
|
||||
accumulated_content += content
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=content,
|
||||
is_chunk=True,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
if isinstance(data, dict):
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "answer" and data.get("role") == "assistant":
|
||||
final_content = data.get("content", "")
|
||||
if not accumulated_content and final_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
elif event_type == "done":
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
error_msg = (
|
||||
data.get("message", "未知错误")
|
||||
if isinstance(data, dict)
|
||||
else str(data)
|
||||
)
|
||||
logger.error(f"Coze 流式响应错误: {error_msg}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 错误: {error_msg}",
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="LLM 未响应任何内容。",
|
||||
is_chunk=False,
|
||||
)
|
||||
elif message_started and accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 流式请求失败: {str(e)}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 流式请求失败: {str(e)}",
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
async def forget(self, session_id: str):
|
||||
"""清空指定会话的上下文"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if user_id in self.file_id_cache:
|
||||
self.file_id_cache.pop(user_id, None)
|
||||
|
||||
if not conversation_id:
|
||||
return True
|
||||
|
||||
try:
|
||||
response = await self.api_client.clear_context(conversation_id)
|
||||
|
||||
if "code" in response and response["code"] == 0:
|
||||
self.conversation_ids.pop(user_id, None)
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空 Coze 会话失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_current_key(self):
|
||||
"""获取当前API Key"""
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key: str):
|
||||
"""设置新的API Key"""
|
||||
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
"""获取可用模型列表"""
|
||||
return [f"bot_{self.bot_id}"]
|
||||
|
||||
def get_model(self):
|
||||
"""获取当前模型"""
|
||||
return f"bot_{self.bot_id}"
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型(在Coze中是Bot ID)"""
|
||||
if model.startswith("bot_"):
|
||||
self.bot_id = model[4:]
|
||||
else:
|
||||
self.bot_id = model
|
||||
|
||||
async def get_human_readable_context(
|
||||
self, session_id: str, page: int = 1, page_size: int = 10
|
||||
):
|
||||
"""获取人类可读的上下文历史"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if not conversation_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await self.api_client.get_message_list(
|
||||
conversation_id=conversation_id,
|
||||
order="desc",
|
||||
limit=page_size,
|
||||
offset=(page - 1) * page_size,
|
||||
)
|
||||
|
||||
if data.get("code") != 0:
|
||||
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
||||
return []
|
||||
|
||||
messages = data.get("data", {}).get("messages", [])
|
||||
|
||||
readable_history = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
msg_type = msg.get("type", "")
|
||||
|
||||
if role == "user":
|
||||
readable_history.append(f"用户: {content}")
|
||||
elif role == "assistant" and msg_type == "answer":
|
||||
readable_history.append(f"助手: {content}")
|
||||
|
||||
return readable_history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def terminate(self):
|
||||
"""清理资源"""
|
||||
await self.api_client.close()
|
||||
@@ -1,15 +1,14 @@
|
||||
import re
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core import logger, sp
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
|
||||
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
session_id=None,
|
||||
image_urls=[],
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
assert isinstance(response, ApplicationResponse)
|
||||
|
||||
logger.debug(f"dashscope resp: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "")
|
||||
output_text = response.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_str = ""
|
||||
for ref in response.output.get("doc_references", []):
|
||||
for ref in response.output.get("doc_references", []) or []:
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
import os
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
import aiohttp
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
|
||||
|
||||
try:
|
||||
from dashscope.aigc.multimodal_conversation import MultiModalConversation
|
||||
except (
|
||||
ImportError
|
||||
): # pragma: no cover - older dashscope versions without Qwen TTS support
|
||||
MultiModalConversation = None
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
@@ -26,16 +38,112 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
model = self.get_model()
|
||||
if not model:
|
||||
raise RuntimeError("Dashscope TTS model is not configured.")
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
if self._is_qwen_tts_model(model):
|
||||
audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
|
||||
else:
|
||||
audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
|
||||
|
||||
if not audio_bytes:
|
||||
raise RuntimeError(
|
||||
"Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
|
||||
)
|
||||
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
|
||||
with open(path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
return path
|
||||
|
||||
def _call_qwen_tts(self, model: str, text: str):
|
||||
if MultiModalConversation is None:
|
||||
raise RuntimeError(
|
||||
"dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"text": text,
|
||||
"api_key": self.chosen_api_key,
|
||||
"voice": self.voice or "Cherry",
|
||||
}
|
||||
if not self.voice:
|
||||
logging.warning(
|
||||
"No voice specified for Qwen TTS model, using default 'Cherry'."
|
||||
)
|
||||
return MultiModalConversation.call(**kwargs)
|
||||
|
||||
async def _synthesize_with_qwen_tts(
|
||||
self, model: str, text: str
|
||||
) -> Tuple[Optional[bytes], str]:
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||
audio_bytes = await self._extract_audio_from_response(response)
|
||||
if not audio_bytes:
|
||||
raise RuntimeError(
|
||||
f"Audio synthesis failed for model '{model}'. {response}"
|
||||
)
|
||||
ext = ".wav"
|
||||
return audio_bytes, ext
|
||||
|
||||
async def _extract_audio_from_response(self, response) -> Optional[bytes]:
|
||||
output = getattr(response, "output", None)
|
||||
audio_obj = getattr(output, "audio", None) if output is not None else None
|
||||
if not audio_obj:
|
||||
return None
|
||||
|
||||
data_b64 = getattr(audio_obj, "data", None)
|
||||
if data_b64:
|
||||
try:
|
||||
return base64.b64decode(data_b64)
|
||||
except (ValueError, TypeError):
|
||||
logging.error("Failed to decode base64 audio data.")
|
||||
return None
|
||||
|
||||
url = getattr(audio_obj, "url", None)
|
||||
if url:
|
||||
return await self._download_audio_from_url(url)
|
||||
return None
|
||||
|
||||
async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
|
||||
if not url:
|
||||
return None
|
||||
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as response:
|
||||
return await response.read()
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
||||
logging.error(f"Failed to download audio from URL {url}: {e}")
|
||||
return None
|
||||
|
||||
async def _synthesize_with_cosyvoice(
|
||||
self, model: str, text: str
|
||||
) -> Tuple[Optional[bytes], str]:
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=model,
|
||||
voice=self.voice,
|
||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||
)
|
||||
audio = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.synthesizer.call, text, self.timeout_ms
|
||||
loop = asyncio.get_event_loop()
|
||||
audio_bytes = await loop.run_in_executor(
|
||||
None, synthesizer.call, text, self.timeout_ms
|
||||
)
|
||||
with open(path, "wb") as f:
|
||||
f.write(audio)
|
||||
return path
|
||||
if not audio_bytes:
|
||||
resp = synthesizer.get_response()
|
||||
if resp and isinstance(resp, dict):
|
||||
raise RuntimeError(
|
||||
f"Audio synthesis failed for model '{model}'. {resp}".strip()
|
||||
)
|
||||
return audio_bytes, ".wav"
|
||||
|
||||
def _is_qwen_tts_model(self, model: str) -> bool:
|
||||
model_lower = model.lower()
|
||||
return "tts" in model_lower and model_lower.startswith("qwen")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user