Compare commits
428 Commits
v3.4.0
...
feat-webui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b54a24037 | ||
|
|
2807e1e892 | ||
|
|
0a2abd8214 | ||
|
|
8beb7acdb1 | ||
|
|
466c80b94d | ||
|
|
36c0cfc9a9 | ||
|
|
35ba1b3345 | ||
|
|
d00821d1c7 | ||
|
|
6c1b3f242b | ||
|
|
9f9da1e0c9 | ||
|
|
14fb4b70bd | ||
|
|
b1049540a4 | ||
|
|
5e2909df33 | ||
|
|
c122dad21f | ||
|
|
48ae686602 | ||
|
|
bf2c3a1a81 | ||
|
|
96e7a93886 | ||
|
|
dba1ed1e19 | ||
|
|
a24514876b | ||
|
|
466a1c1c41 | ||
|
|
a2d5e9f40f | ||
|
|
1bbff1d161 | ||
|
|
0948bae99b | ||
|
|
850db41596 | ||
|
|
7bafc87e2b | ||
|
|
1a0de02a15 | ||
|
|
6d5d278624 | ||
|
|
3b4cc48fa0 | ||
|
|
c908461088 | ||
|
|
53d1398d30 | ||
|
|
782c0367d0 | ||
|
|
4678222e9b | ||
|
|
f71dc3e4be | ||
|
|
f6233893bd | ||
|
|
6427bcf130 | ||
|
|
8fa41b706c | ||
|
|
4706c4438d | ||
|
|
0c8ebc2b06 | ||
|
|
b3b5ebc2ca | ||
|
|
b8aa23ccc5 | ||
|
|
364843db29 | ||
|
|
aa56c8f7e6 | ||
|
|
8e9fd27058 | ||
|
|
b75908cb2a | ||
|
|
af6df49ce1 | ||
|
|
bd3bdb5769 | ||
|
|
98fe193b21 | ||
|
|
26cbc9e8b1 | ||
|
|
ebb8c43fd0 | ||
|
|
8c7344f1c4 | ||
|
|
5c32a17787 | ||
|
|
aff520e69a | ||
|
|
45e627c33c | ||
|
|
7a1b158f83 | ||
|
|
6374c5d49d | ||
|
|
fd460b19d4 | ||
|
|
dff7cc4ca5 | ||
|
|
d013320bec | ||
|
|
fc6dcfaf21 | ||
|
|
a001270bd2 | ||
|
|
9e67883fbd | ||
|
|
f1a448708c | ||
|
|
a4bfa96502 | ||
|
|
595b83a256 | ||
|
|
8d34f77321 | ||
|
|
67095f97b1 | ||
|
|
50740c94ab | ||
|
|
4db4cfeda2 | ||
|
|
ad13cef89c | ||
|
|
855fc6fcd1 | ||
|
|
8f12244e51 | ||
|
|
fe0213465c | ||
|
|
f984047004 | ||
|
|
19e9e2d090 | ||
|
|
7fe3b97d00 | ||
|
|
9cd243da47 | ||
|
|
e43208c2e9 | ||
|
|
dc016fc22f | ||
|
|
c6f037cae2 | ||
|
|
f049830e28 | ||
|
|
dd1995ae0b | ||
|
|
23dc233569 | ||
|
|
0977aa7d0d | ||
|
|
24862b0672 | ||
|
|
f05a57efc3 | ||
|
|
65331a9d7c | ||
|
|
f7ae287e40 | ||
|
|
45f380b1f6 | ||
|
|
9e6b329df4 | ||
|
|
43cd34d94c | ||
|
|
9fa00aff9a | ||
|
|
9a56dcb1be | ||
|
|
fdfe7bbe59 | ||
|
|
3a99a60792 | ||
|
|
fa2b4e14df | ||
|
|
35322a6900 | ||
|
|
2ccf29d61e | ||
|
|
b068013343 | ||
|
|
d839e72998 | ||
|
|
d7c9a8ed29 | ||
|
|
6837d4d692 | ||
|
|
8aba83735b | ||
|
|
aa51187747 | ||
|
|
5f07a9ae95 | ||
|
|
a2ca767bf4 | ||
|
|
5806c74e7c | ||
|
|
0481e1d45e | ||
|
|
3177b61421 | ||
|
|
6009cf5dfa | ||
|
|
0a970e8c31 | ||
|
|
aa276ca6af | ||
|
|
9f02dd13ff | ||
|
|
609e723322 | ||
|
|
c564a1d53e | ||
|
|
a7fe31f28b | ||
|
|
a84dc599d6 | ||
|
|
8da029add9 | ||
|
|
ba45a2d270 | ||
|
|
cb56b22aea | ||
|
|
23cc5b31ba | ||
|
|
e8d99f0460 | ||
|
|
6bcd10cd5c | ||
|
|
619fb20c5f | ||
|
|
386a312e96 | ||
|
|
2759d347e6 | ||
|
|
b6ec327b49 | ||
|
|
ee02d622ba | ||
|
|
5c4a6083f5 | ||
|
|
49e63a3d3d | ||
|
|
6bae9dc9ed | ||
|
|
5fa1979a46 | ||
|
|
b40d4fa315 | ||
|
|
4d2ff7cd5b | ||
|
|
d8ec0e64d0 | ||
|
|
82e979cc07 | ||
|
|
8c132a51f5 | ||
|
|
40bd372cc1 | ||
|
|
212e114270 | ||
|
|
b0e9de6951 | ||
|
|
3489522bbb | ||
|
|
96237abc03 | ||
|
|
7155b4f0ac | ||
|
|
a8b2b09e0f | ||
|
|
6858b8c555 | ||
|
|
0e493b1a0e | ||
|
|
37d478f970 | ||
|
|
7d0d42a49f | ||
|
|
0eb1684ef1 | ||
|
|
9b0b723143 | ||
|
|
532bc6e1e6 | ||
|
|
fe3ed4c454 | ||
|
|
b5ec89e586 | ||
|
|
895e7397c2 | ||
|
|
59b767957a | ||
|
|
17d4bf8f22 | ||
|
|
836be3b097 | ||
|
|
310415bea9 | ||
|
|
aafc1276a9 | ||
|
|
2993e794cc | ||
|
|
58cb9cfb2d | ||
|
|
fbdf0901d5 | ||
|
|
af8c81b621 | ||
|
|
06b5275e48 | ||
|
|
ad95572d5f | ||
|
|
0021cfc4bc | ||
|
|
aebc7850f4 | ||
|
|
1b7efbc607 | ||
|
|
3800e96d14 | ||
|
|
461f1bb07c | ||
|
|
7d4c07e4f6 | ||
|
|
31b788f463 | ||
|
|
96ab761f73 | ||
|
|
2b3f05c039 | ||
|
|
f2e8303b66 | ||
|
|
2a614b545b | ||
|
|
5c0ab21f68 | ||
|
|
689d109438 | ||
|
|
2a6934b283 | ||
|
|
760cb94e9a | ||
|
|
2a6cff0013 | ||
|
|
ce578f0417 | ||
|
|
1745bdb9e2 | ||
|
|
3f90b89c3c | ||
|
|
f343e40d15 | ||
|
|
5cc4be9e65 | ||
|
|
da5aada002 | ||
|
|
07f2ee9ad9 | ||
|
|
12f4e1146f | ||
|
|
92c57e5476 | ||
|
|
a923baacd8 | ||
|
|
999b094d55 | ||
|
|
d4213f2352 | ||
|
|
3f65c9a066 | ||
|
|
1d427e2645 | ||
|
|
36414c4b00 | ||
|
|
47e253d76c | ||
|
|
b73cf84df0 | ||
|
|
a5b885a774 | ||
|
|
0c785413da | ||
|
|
482d7ef5f7 | ||
|
|
9f9073c0ff | ||
|
|
ef05ff4abd | ||
|
|
5848aae435 | ||
|
|
fb06f33de0 | ||
|
|
0d7ddb149e | ||
|
|
4f2d7b9c4e | ||
|
|
c02ed96f6f | ||
|
|
3b2ac891b2 | ||
|
|
ef0108881b | ||
|
|
af48975a6b | ||
|
|
6441b149ab | ||
|
|
f8892881f8 | ||
|
|
228aec5401 | ||
|
|
68ad48ff55 | ||
|
|
541ba64032 | ||
|
|
2d870b798c | ||
|
|
0f1fe1ab63 | ||
|
|
73cc86ddb1 | ||
|
|
23128f4be2 | ||
|
|
92200d0e82 | ||
|
|
d6e8655792 | ||
|
|
37076d7920 | ||
|
|
78347ec91b | ||
|
|
9ded102a0a | ||
|
|
59b7d8b8cb | ||
|
|
f5b97f6762 | ||
|
|
d47da241af | ||
|
|
4611ce15eb | ||
|
|
aa8c56a688 | ||
|
|
ef44d4471a | ||
|
|
5581eae957 | ||
|
|
ec46dfaac9 | ||
|
|
6042a047bd | ||
|
|
6ca9e2a753 | ||
|
|
618eabfe5c | ||
|
|
bb5db2e9d0 | ||
|
|
97e4d169b3 | ||
|
|
50e44b1473 | ||
|
|
38588dd3fa | ||
|
|
d183388347 | ||
|
|
1e69d59384 | ||
|
|
00f008f94d | ||
|
|
3c28001a74 | ||
|
|
76a6218be6 | ||
|
|
6c1de1bbd6 | ||
|
|
d7678081da | ||
|
|
5e4ba563cb | ||
|
|
8afbe77b0a | ||
|
|
2ef139b59a | ||
|
|
1f0d2d9b89 | ||
|
|
37a1f144ab | ||
|
|
9a7a654596 | ||
|
|
9abccd63cf | ||
|
|
93fea77182 | ||
|
|
19797243f6 | ||
|
|
c9c733d925 | ||
|
|
a7d7678c78 | ||
|
|
c0911921c7 | ||
|
|
4a4241d57a | ||
|
|
c9426bb6eb | ||
|
|
db4abd169a | ||
|
|
80b6958599 | ||
|
|
80058c781a | ||
|
|
44bd2e36f3 | ||
|
|
3589a5e5be | ||
|
|
13ef033f0e | ||
|
|
3f8c68bbca | ||
|
|
4275cea82b | ||
|
|
a0bcb5339a | ||
|
|
43deec4a4b | ||
|
|
2bc433a30b | ||
|
|
eb2b395932 | ||
|
|
2bfd1c0bf2 | ||
|
|
7228c4b13f | ||
|
|
9351d7471f | ||
|
|
1cf49998bc | ||
|
|
6ae86597e8 | ||
|
|
c578ff25bd | ||
|
|
2934a3e3be | ||
|
|
ceaa69da75 | ||
|
|
fa8e731576 | ||
|
|
685c0a106a | ||
|
|
7f539090dd | ||
|
|
2089273f95 | ||
|
|
838bb4c7ad | ||
|
|
637acd1a12 | ||
|
|
03fa9a847f | ||
|
|
d488c88e78 | ||
|
|
baae842210 | ||
|
|
ec1fb838b6 | ||
|
|
13281179df | ||
|
|
276a42c9a1 | ||
|
|
7a70a730ba | ||
|
|
d0fe59631c | ||
|
|
106892e933 | ||
|
|
19543a41b3 | ||
|
|
b172b760ab | ||
|
|
4b5d49cb41 | ||
|
|
3fd35b6058 | ||
|
|
5f86c4ab99 | ||
|
|
c94a7f6629 | ||
|
|
7d6beb4141 | ||
|
|
e2117e690a | ||
|
|
fb791290e2 | ||
|
|
5dd1488b5d | ||
|
|
529cd64d82 | ||
|
|
d2bd3e8da8 | ||
|
|
e42ce7dd86 | ||
|
|
40709462ee | ||
|
|
2ad6c01a4d | ||
|
|
70c12e788e | ||
|
|
1713791c90 | ||
|
|
9aa23fd412 | ||
|
|
e4ba09cd93 | ||
|
|
171fdf1fbc | ||
|
|
01f4e0b961 | ||
|
|
be2d5a91c7 | ||
|
|
a1d89d9478 | ||
|
|
98d1dc3b65 | ||
|
|
b80eb3acc0 | ||
|
|
05ccc1995b | ||
|
|
0de244889e | ||
|
|
e6c5c3a493 | ||
|
|
164aa2ccd2 | ||
|
|
f1599e26b3 | ||
|
|
ed64a4d32d | ||
|
|
2ee4b431d4 | ||
|
|
cd8a73ed19 | ||
|
|
e6c985ce4e | ||
|
|
a20446aeb9 | ||
|
|
7b23d76559 | ||
|
|
8315cf5818 | ||
|
|
ed16265bde | ||
|
|
dff205faf6 | ||
|
|
9aae8aee0c | ||
|
|
7c818ced2b | ||
|
|
218e887558 | ||
|
|
a68860b35a | ||
|
|
82d4d43383 | ||
|
|
94618e8feb | ||
|
|
55de7d4494 | ||
|
|
7ed639f741 | ||
|
|
41f2870c29 | ||
|
|
ba198490fa | ||
|
|
0f9ab082ab | ||
|
|
97b58965f2 | ||
|
|
f2566c68e3 | ||
|
|
a456bf5449 | ||
|
|
a09998f910 | ||
|
|
be662b913c | ||
|
|
e7ddc8448d | ||
|
|
29374f8d8a | ||
|
|
359b971103 | ||
|
|
fbdb1ae208 | ||
|
|
22c13c1eff | ||
|
|
5fc63aeaf1 | ||
|
|
d4f32673ab | ||
|
|
480dffb51b | ||
|
|
966df00124 | ||
|
|
3e2b4bc727 | ||
|
|
5929a8d42b | ||
|
|
f8ab40eb39 | ||
|
|
55e9233b93 | ||
|
|
b7277b51fd | ||
|
|
1fa9111b2b | ||
|
|
90a9e496d9 | ||
|
|
2a7dce1eb0 | ||
|
|
0c0841cc03 | ||
|
|
4c9fe016bf | ||
|
|
acc90f140c | ||
|
|
68a7bc3930 | ||
|
|
12ea64be0e | ||
|
|
7f30a673f7 | ||
|
|
897e100c32 | ||
|
|
0d4ad5cb31 | ||
|
|
b124bd0d0e | ||
|
|
6bc2f84602 | ||
|
|
d787a28c40 | ||
|
|
6b078a5731 | ||
|
|
17dddbfe21 | ||
|
|
3ff3c9e144 | ||
|
|
f5a37d82cc | ||
|
|
d3d428dc9d | ||
|
|
8dc8c5b5dc | ||
|
|
e6b06f914b | ||
|
|
4dc502a8b6 | ||
|
|
b1d1a13d5f | ||
|
|
75cc4cac5a | ||
|
|
1b7e4fbbdc | ||
|
|
9789e2f6c1 | ||
|
|
b8fb0bee24 | ||
|
|
419f77e245 | ||
|
|
59b1c3473b | ||
|
|
6db58ca375 | ||
|
|
4832b342b0 | ||
|
|
6cec542402 | ||
|
|
9644791783 | ||
|
|
5031c307d1 | ||
|
|
aa49539e3e | ||
|
|
7b4118493b | ||
|
|
d1cc9ba4ce | ||
|
|
e0e92139d7 | ||
|
|
62039392bb | ||
|
|
b72c69892e | ||
|
|
e6205e9aad | ||
|
|
b8a6fb1720 | ||
|
|
7c06d82f27 | ||
|
|
d92cb0f500 | ||
|
|
7fa72f2fe9 | ||
|
|
21d480a3b5 | ||
|
|
771c045844 | ||
|
|
e6ce484c15 | ||
|
|
102a92f62d | ||
|
|
6c7ac70701 | ||
|
|
9d8372289f | ||
|
|
766f6a1ba2 | ||
|
|
193ff24f4c | ||
|
|
c675017374 | ||
|
|
86cb852507 | ||
|
|
73494e0d7d | ||
|
|
ec61aa1b6f | ||
|
|
6df0e78b22 | ||
|
|
63c604359b | ||
|
|
08212588a0 | ||
|
|
c8518ce827 | ||
|
|
94434e3fc0 | ||
|
|
9f3af95198 | ||
|
|
acb3af8ab8 |
@@ -16,3 +16,5 @@ venv*/
|
||||
ENV/
|
||||
.conda/
|
||||
README*.md
|
||||
dashboard/
|
||||
data/
|
||||
35
.github/workflows/auto_release.yml
vendored
Normal file
35
.github/workflows/auto_release.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
name: Auto Release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo ${{ github.ref_name }} > dist/assets/version
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Fetch Changelog
|
||||
run: |
|
||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Create Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
bodyFile: ${{ env.changelog }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
24
.github/workflows/coverage_test.yml
vendored
24
.github/workflows/coverage_test.yml
vendored
@@ -1,7 +1,14 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
push
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -21,17 +28,16 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
export LLM_MODEL=${{ secrets.LLM_MODEL }}
|
||||
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
|
||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
|
||||
7
.github/workflows/docker-image.yml
vendored
7
.github/workflows/docker-image.yml
vendored
@@ -1,8 +1,9 @@
|
||||
name: Docker Image CI/CD
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -35,7 +36,7 @@ jobs:
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -17,3 +17,12 @@ addons/plugins
|
||||
|
||||
tests/astrbot_plugin_openai
|
||||
chroma
|
||||
node_modules/
|
||||
.DS_Store
|
||||
package-lock.json
|
||||
package.json
|
||||
venv/*
|
||||
packages/python_interpreter/workplace
|
||||
.venv/*
|
||||
|
||||
.conda/
|
||||
@@ -12,7 +12,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install -r requirements.txt
|
||||
RUN python -m pip install -r requirements.txt --no-cache-dir
|
||||
|
||||
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
172
README.md
172
README.md
@@ -1,46 +1,158 @@
|
||||
|
||||
<p align="center">
|
||||
<img width=200 src="https://github.com/user-attachments/assets/3dd6a669-0830-4db4-b821-c8b279ea19a6"/>
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<h1>AstrBot</h1>
|
||||
|
||||
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||
</a>
|
||||

|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
|
||||
<a href="https://astrbot.soulter.top/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
## ✨ 多消息平台部署
|
||||
## ✨ 主要功能
|
||||
|
||||
1. QQ 群、QQ 频道、微信、Telegram。
|
||||
2. 支持文本转图片,Markdown 渲染。
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
## ✨ 多 LLM 配置
|
||||
> [!TIP]
|
||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
|
||||
|
||||
1. 适配 OpenAI API,支持接入 Gemini、GPT、Llama、Claude、DeepSeek、GLM 等各种大语言模型。
|
||||
2. 支持 OneAPI 等分发平台。
|
||||
3. 支持 LLMTuner 载入微调模型。
|
||||
4. 支持 Ollama 载入自部署模型。
|
||||
4. 支持网页搜索(Web Search)。
|
||||
## ✨ 使用方式
|
||||
|
||||
## ✨ 管理面板
|
||||
#### Docker 部署
|
||||
|
||||
1. 支持可视化修改配置
|
||||
2. 日志实时查看
|
||||
3. 简单的信息统计
|
||||
4. 插件管理
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
||||
|
||||
#### 手动部署
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 飞书 | ✔ | 群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
|
||||
# 🦌 接下来的路线图
|
||||
|
||||
> [!TIP]
|
||||
> 欢迎在 Issue 提出更多建议 <3
|
||||
|
||||
- [ ] 完善并保证目前所有平台适配器的功能一致性
|
||||
- [ ] 优化插件接口
|
||||
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
|
||||
- [ ] 完善“聊天增强”部分,支持持久化记忆
|
||||
- [ ] 规划 i18n
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
对于新功能的添加,请先通过 Issue 讨论。
|
||||
|
||||
## 🌟 支持
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
> [!NOTE]
|
||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
|
||||
_✨ 自然语言待办事项 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ 插件系统——部分插件展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
||||
|
||||
_✨ 管理面板 ✨_
|
||||
|
||||

|
||||
|
||||
_✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||
3. Please ensure compliance with local laws and regulations when using this project.
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
@@ -51,23 +163,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
3. 表情包理解与回复
|
||||
4. TTS
|
||||
-->
|
||||
## ✨ 云部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
## ❤️ 贡献
|
||||
_私は、高性能ですから!_
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
对于新功能的添加,请先通过 Issue 讨论。
|
||||
|
||||
## 🔭 展望
|
||||
|
||||
1. 更强大的 Agent 系统。
|
||||
2. 打造插件工作流平台。
|
||||
|
||||
## ✨ Support
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
170
README_ja.md
Normal file
170
README_ja.md
Normal file
@@ -0,0 +1,170 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
|
||||
<a href="https://astrbot.app/">ドキュメントを見る</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
|
||||
</div>
|
||||
|
||||
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
|
||||
|
||||
## ✨ 主な機能
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
||||
|
||||
> [!TIP]
|
||||
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
|
||||
|
||||
## ✨ 使用方法
|
||||
|
||||
#### Docker デプロイ
|
||||
|
||||
公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください。
|
||||
|
||||
#### Windows ワンクリックインストーラーのデプロイ
|
||||
|
||||
コンピュータに Python(>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
|
||||
|
||||
#### Replit デプロイ
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### CasaOS デプロイ
|
||||
|
||||
コミュニティが提供するデプロイ方法です。
|
||||
|
||||
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
|
||||
|
||||
#### 手動デプロイ
|
||||
|
||||
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
|
||||
|
||||
## ⚡ メッセージプラットフォームのサポート状況
|
||||
|
||||
| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
|
||||
| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
|
||||
| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
|
||||
| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
|
||||
| Feishu | ✔ | グループチャット | テキスト、画像 |
|
||||
| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
|
||||
| Discord | 🚧 | 計画中 | - |
|
||||
| WhatsApp | 🚧 | 計画中 | - |
|
||||
| Xiaoai 音響 | 🚧 | 計画中 | - |
|
||||
|
||||
# 🦌 今後のロードマップ
|
||||
|
||||
> [!TIP]
|
||||
> Issue でさらに多くの提案を歓迎します <3
|
||||
|
||||
- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
|
||||
- [ ] プラグインインターフェースの最適化
|
||||
- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
|
||||
- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
|
||||
- [ ] i18n の計画
|
||||
|
||||
## ❤️ 貢献
|
||||
|
||||
Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
|
||||
|
||||
新機能の追加については、まず Issue で議論してください。
|
||||
|
||||
## 🌟 サポート
|
||||
|
||||
- このプロジェクトに Star を付けてください!
|
||||
- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
|
||||
- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
|
||||
|
||||
## ✨ デモ
|
||||
|
||||
> [!NOTE]
|
||||
> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
_✨ Docker ベースのサンドボックス化されたコードエグゼキューター(ベータテスト中)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多モーダル、ウェブ検索、長文の画像変換(設定可能)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
|
||||
_✨ 自然言語タスク ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ プラグインシステム - 一部のプラグインの展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width="600">
|
||||
|
||||
_✨ 管理パネル ✨_
|
||||
|
||||

|
||||
|
||||
_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
## スポンサー
|
||||
|
||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||
|
||||
## 免責事項
|
||||
|
||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
||||
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
|
||||
<!-- ## ✨ ATRI [ベータテスト]
|
||||
|
||||
この機能はプラグインとしてロードされます。プラグインリポジトリのアドレス:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||
|
||||
1. 《ATRI ~ My Dear Moments》の主人公 ATRI のキャラクターセリフを微調整データセットとして使用した `Qwen1.5-7B-Chat Lora` 微調整モデル。
|
||||
2. 長期記憶
|
||||
3. ミームの理解と返信
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.provider.register import register_llm_tool as llm_tool
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"personalities",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"sp"
|
||||
]
|
||||
@@ -1,9 +1,8 @@
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.provider.register import register_llm_tool as llm_tool
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
# event
|
||||
from astrbot.core.message.message_event_result import (
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult, MessageChain, CommandResult, EventResultType
|
||||
MessageEventResult,
|
||||
MessageChain,
|
||||
CommandResult,
|
||||
EventResultType,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
__all__ = [
|
||||
'MessageEventResult', 'MessageChain', 'CommandResult', 'EventResultType', 'AstrMessageEvent'
|
||||
"MessageEventResult",
|
||||
"MessageChain",
|
||||
"CommandResult",
|
||||
"EventResultType",
|
||||
"AstrMessageEvent",
|
||||
"ResultContentType",
|
||||
]
|
||||
@@ -4,12 +4,19 @@ from astrbot.core.star.register import (
|
||||
register_event_message_type as event_message_type,
|
||||
register_regex as regex,
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
register_permission_type as permission_type
|
||||
register_permission_type as permission_type,
|
||||
register_custom_filter as custom_filter,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_on_llm_response as on_llm_response,
|
||||
register_llm_tool as llm_tool,
|
||||
register_on_decorating_result as on_decorating_result,
|
||||
register_after_message_sent as after_message_sent
|
||||
)
|
||||
|
||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
|
||||
from astrbot.core.star.filter.custom_filter import CustomFilter
|
||||
|
||||
__all__ = [
|
||||
'command',
|
||||
@@ -23,5 +30,12 @@ __all__ = [
|
||||
'PlatformAdapterTypeFilter',
|
||||
'PlatformAdapterType',
|
||||
'PermissionTypeFilter',
|
||||
'CustomFilter',
|
||||
'custom_filter',
|
||||
'PermissionType',
|
||||
'on_llm_request',
|
||||
'llm_tool',
|
||||
'on_decorating_result',
|
||||
'after_message_sent',
|
||||
'on_llm_response'
|
||||
]
|
||||
@@ -3,3 +3,4 @@ from astrbot.core.platform import (
|
||||
)
|
||||
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
from astrbot.core.message.components import *
|
||||
@@ -1 +1,2 @@
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
|
||||
@@ -1,12 +1,26 @@
|
||||
import os
|
||||
import asyncio
|
||||
from .log import LogManager, LogBroker
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
html_renderer = HtmlRenderer()
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get('t2i_endpoint', 'https://t2i.soulter.top/text2img')
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
if os.environ.get('TESTING', ""):
|
||||
logger.setLevel('DEBUG')
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = SharedPreferences() # 简单的偏好设置存储
|
||||
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import json
|
||||
import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
@@ -13,29 +13,72 @@ class RateLimitStrategy(enum.Enum):
|
||||
DISCARD = "discard"
|
||||
|
||||
class AstrBotConfig(dict):
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
|
||||
|
||||
def __init__(self):
|
||||
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
|
||||
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
default_config: dict = DEFAULT_CONFIG,
|
||||
schema: dict = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
|
||||
object.__setattr__(self, 'config_path', config_path)
|
||||
object.__setattr__(self, 'default_config', default_config)
|
||||
object.__setattr__(self, 'schema', schema)
|
||||
|
||||
if schema:
|
||||
default_config = self._config_schema_to_default_config(schema)
|
||||
|
||||
if not self.check_exist():
|
||||
'''不存在时载入默认配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
|
||||
has_new = self.check_config_integrity(default_config, conf)
|
||||
self.update(conf)
|
||||
if has_new:
|
||||
self.save_config()
|
||||
|
||||
self.update(conf)
|
||||
|
||||
def _config_schema_to_default_config(self, schema: dict) -> dict:
|
||||
'''将 Schema 转换成 Config'''
|
||||
conf = {}
|
||||
|
||||
def _parse_schema(schema: dict, conf: dict):
|
||||
for k, v in schema.items():
|
||||
if v['type'] not in DEFAULT_VALUE_MAP:
|
||||
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
|
||||
if 'default' in v:
|
||||
default = v['default']
|
||||
else:
|
||||
default = DEFAULT_VALUE_MAP[v['type']]
|
||||
|
||||
if v['type'] == 'object':
|
||||
conf[k] = {}
|
||||
_parse_schema(v['items'], conf[k])
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
_parse_schema(schema, conf)
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
'''检查配置完整性,如果有新的配置项则返回 True'''
|
||||
has_new = False
|
||||
@@ -61,7 +104,7 @@ class AstrBotConfig(dict):
|
||||
'''
|
||||
if replace_config:
|
||||
self.update(replace_config)
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
with open(self.config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(self, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def __getattr__(self, item):
|
||||
@@ -81,4 +124,4 @@ class AstrBotConfig(dict):
|
||||
self[key] = value
|
||||
|
||||
def check_exist(self) -> bool:
|
||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
||||
return os.path.exists(self.config_path)
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.0"
|
||||
VERSION = "3.4.30"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -17,39 +17,84 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"reply_prefix": "",
|
||||
"forward_threshold": 200,
|
||||
"enable_id_white_list": True,
|
||||
"id_whitelist": [],
|
||||
"id_whitelist_log": True,
|
||||
"wl_ignore_admin_on_group": True,
|
||||
"wl_ignore_admin_on_friend": True,
|
||||
"reply_with_mention": False,
|
||||
"reply_with_quote": False,
|
||||
"path_mapping": [],
|
||||
"segmented_reply": {
|
||||
"enable": False,
|
||||
"only_llm_result": True,
|
||||
"interval_method": "random",
|
||||
"interval": "1.5,3.5",
|
||||
"log_base": 2.6,
|
||||
"words_count_threshold": 150,
|
||||
"regex": ".*?[。?!~…]+|.+$",
|
||||
"content_cleanup_rule": "",
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
"identifier": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"default_personality": "default",
|
||||
"prompt_prefix": "",
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_provider_id": "",
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
"possibility_reply": 0.1,
|
||||
"prompt": "",
|
||||
"whitelist": []
|
||||
}
|
||||
},
|
||||
"content_safety": {
|
||||
"also_use_in_response": False,
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||
},
|
||||
"admins_id": [],
|
||||
"admins_id": [
|
||||
"astrbot"
|
||||
],
|
||||
"t2i": False,
|
||||
"t2i_word_threshold": 150,
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "astrbot",
|
||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||
"port": 6185
|
||||
},
|
||||
"platform": [],
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": ""
|
||||
"plugin_repo_mirror": "",
|
||||
"knowledge_db": {},
|
||||
"persona": [],
|
||||
}
|
||||
|
||||
|
||||
@@ -71,20 +116,45 @@ CONFIG_METADATA_2 = {
|
||||
"enable_group_c2c": True,
|
||||
"enable_guild_direct_message": True,
|
||||
},
|
||||
"qq_official_webhook(QQ)": {
|
||||
"id": "default",
|
||||
"type": "qq_official_webhook",
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"port": 6196
|
||||
},
|
||||
"aiocqhtp(QQ)": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_host": "0.0.0.0",
|
||||
"ws_reverse_port": 6199,
|
||||
},
|
||||
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
"lark(飞书)": {
|
||||
"id": "lark",
|
||||
"type": "lark",
|
||||
"enable": False,
|
||||
"lark_bot_name": "",
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"domain": "https://open.feishu.cn"
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。",
|
||||
"hint": "用于在多实例下方便管理和识别。自定义,ID 不能重复。",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -126,6 +196,12 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"obvious_hint": True
|
||||
}
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -152,6 +228,58 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"no_permission_reply": {
|
||||
"description": "无权限回复",
|
||||
"type": "bool",
|
||||
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用分段回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"only_llm_result": {
|
||||
"description": "仅对 LLM 结果分段",
|
||||
"type": "bool",
|
||||
},
|
||||
"interval_method": {
|
||||
"description": "间隔时间计算方法",
|
||||
"type": "string",
|
||||
"options": ["random", "log"],
|
||||
"hint": "分段回复的间隔时间计算方法。random 为随机时间,log 为根据消息长度计算,$y=log_{log\_base}(x)$,x为字数,y的单位为秒。",
|
||||
},
|
||||
"interval": {
|
||||
"description": "随机间隔时间(秒)",
|
||||
"type": "string",
|
||||
"hint": "`random` 方法用。每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
|
||||
},
|
||||
"log_base": {
|
||||
"description": "对数函数底数",
|
||||
"type": "float",
|
||||
"hint": "`log` 方法用。对数函数的底数。默认为 2.6",
|
||||
},
|
||||
"words_count_threshold": {
|
||||
"description": "字数阈值",
|
||||
"type": "int",
|
||||
"hint": "超过这个字数的消息不会被分段回复。默认为 150",
|
||||
},
|
||||
"regex": {
|
||||
"description": "正则表达式",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||
},
|
||||
"content_cleanup_rule": {
|
||||
"description": "过滤分段后的内容",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"reply_prefix": {
|
||||
"description": "回复前缀",
|
||||
"type": "string",
|
||||
@@ -162,11 +290,16 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。",
|
||||
},
|
||||
"enable_id_white_list": {
|
||||
"description": "启用 ID 白名单",
|
||||
"type": "bool",
|
||||
},
|
||||
"id_whitelist": {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "int"},
|
||||
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
"description": "打印白名单日志",
|
||||
@@ -181,12 +314,34 @@ CONFIG_METADATA_2 = {
|
||||
"description": "管理员私聊消息无视 ID 白名单",
|
||||
"type": "bool",
|
||||
},
|
||||
"reply_with_mention": {
|
||||
"description": "回复时 @ 发送者",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。",
|
||||
},
|
||||
"reply_with_quote": {
|
||||
"description": "回复时引用消息",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。",
|
||||
},
|
||||
"path_mapping": {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"content_safety": {
|
||||
"description": "内容安全",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"also_use_in_response": {
|
||||
"description": "对大模型响应安全审核",
|
||||
"type": "bool",
|
||||
"hint": "启用后,大模型的响应也会通过内容安全审核。",
|
||||
},
|
||||
"baidu_aip": {
|
||||
"description": "百度内容审核配置",
|
||||
"type": "object",
|
||||
@@ -225,38 +380,86 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"provider_group": {
|
||||
"name": "大语言模型",
|
||||
"name": "服务提供商",
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"description": "大语言模型配置",
|
||||
"description": "服务提供商配置",
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"openai": {
|
||||
"id": "default",
|
||||
"id": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"azure_openai": {
|
||||
"id": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "grok-2-latest",
|
||||
},
|
||||
},
|
||||
"anthropic(claude)": {
|
||||
"id": "claude",
|
||||
"type": "anthropic_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"max_tokens": 4096,
|
||||
},
|
||||
},
|
||||
"ollama": {
|
||||
"id": "ollama_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434",
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
},
|
||||
"gemini": {
|
||||
"gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
},
|
||||
"gemini(googlegenai原生)": {
|
||||
"id": "gemini_default",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
@@ -267,20 +470,44 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek-chat",
|
||||
},
|
||||
},
|
||||
"zhipu": {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"type": "zhipu_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
},
|
||||
"siliconflow": {
|
||||
"id": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
},
|
||||
},
|
||||
"moonshot(kimi)": {
|
||||
"id": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {
|
||||
"model": "moonshot-v1-8k",
|
||||
},
|
||||
},
|
||||
"llmtuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
@@ -290,9 +517,116 @@ CONFIG_METADATA_2 = {
|
||||
"llmtuner_template": "",
|
||||
"finetuning_type": "lora",
|
||||
"quantization_bit": 4,
|
||||
}
|
||||
},
|
||||
"dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
"dify_api_base": "https://api.dify.ai/v1",
|
||||
"dify_workflow_output_key": "",
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"dashscope": {
|
||||
"id": "dashscope",
|
||||
"type": "dashscope",
|
||||
"enable": True,
|
||||
"dashscope_app_type": "agent",
|
||||
"dashscope_api_key": "",
|
||||
"dashscope_app_id": "",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"fastgpt": {
|
||||
"id": "fastgpt",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
"timeout": 60,
|
||||
},
|
||||
"whisper(API)": {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"openai_tts(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "tts-1",
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
},
|
||||
"fishaudio_tts(API)": {
|
||||
"id": "fishaudio_tts",
|
||||
"type": "fishaudio_tts_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "https://api.fish-audio.cn/v1",
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"timeout": "20",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"variables": {
|
||||
"description": "工作流固定输入变量",
|
||||
"type": "object",
|
||||
"obvious_hint": True,
|
||||
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
||||
},
|
||||
# "fastgpt_app_type": {
|
||||
# "description": "应用类型",
|
||||
# "type": "string",
|
||||
# "hint": "FastGPT 应用的应用类型。",
|
||||
# "options": ["agent", "workflow", "plugin"],
|
||||
# "obvious_hint": True,
|
||||
# },
|
||||
"dashscope_app_type": {
|
||||
"description": "应用类型",
|
||||
"type": "string",
|
||||
"hint": "阿里云百炼应用的应用类型。",
|
||||
"options": ["agent", "agent-arrange", "dialog-workflow", "task-workflow"],
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"timeout": {
|
||||
"description": "超时时间",
|
||||
"type": "int",
|
||||
"hint": "超时时间,单位为秒。",
|
||||
},
|
||||
"openai-tts-voice": {
|
||||
"description": "voice",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||
},
|
||||
"fishaudio-tts-character": {
|
||||
"description": "character",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
||||
},
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
@@ -317,7 +651,8 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
"description": "基座模型路径",
|
||||
@@ -360,7 +695,35 @@ CONFIG_METADATA_2 = {
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
"editable": True,
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
"hint": "Dify API Key。此项必填。",
|
||||
},
|
||||
"dify_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
|
||||
},
|
||||
"dify_api_type": {
|
||||
"description": "Dify 应用类型",
|
||||
"type": "string",
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
||||
"options": ["chat", "agent", "workflow"],
|
||||
},
|
||||
"dify_workflow_output_key": {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
"type": "string",
|
||||
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
||||
},
|
||||
"dify_query_input_key": {
|
||||
"description": "Prompt 输入变量名",
|
||||
"type": "string",
|
||||
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||
"obvious": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
@@ -370,7 +733,8 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "是否启用大语言模型聊天。默认启用",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
@@ -380,22 +744,31 @@ CONFIG_METADATA_2 = {
|
||||
"web_search": {
|
||||
"description": "启用网页搜索",
|
||||
"type": "bool",
|
||||
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
"obvious_hint": True,
|
||||
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
},
|
||||
"web_search_link": {
|
||||
"description": "网页搜索引用链接",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||
},
|
||||
"identifier": {
|
||||
"description": "启动识别群员",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"description": "启用日期时间系统提示",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||
},
|
||||
"default_personality": {
|
||||
"description": "默认人格",
|
||||
"description": "默认采用的人格情景的名称",
|
||||
"type": "string",
|
||||
"hint": "默认人格(情境设置/System Prompt)文本。",
|
||||
"hint": "",
|
||||
},
|
||||
"prompt_prefix": {
|
||||
"description": "Prompt 前缀文本",
|
||||
@@ -404,6 +777,151 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
"description": "人格情景设置",
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"新人格情景": {
|
||||
"name": "",
|
||||
"prompt": "",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
},
|
||||
"tmpl_display_title": "name",
|
||||
"items": {
|
||||
"name": {
|
||||
"description": "人格名称",
|
||||
"type": "string",
|
||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"prompt": {
|
||||
"description": "设定(系统提示词)",
|
||||
"type": "text",
|
||||
"hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
|
||||
},
|
||||
"begin_dialogs": {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"description": "语音转文本(STT)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用语音转文本(STT)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
"type": "string",
|
||||
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"description": "文本转语音(TTS)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用文本转语音(TTS)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"description": "聊天记忆增强(Beta)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious_hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "启用图像转述(需要模型支持)",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"image_caption_provider_id": {
|
||||
"description": "图像转述提供商 ID",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string",
|
||||
},
|
||||
"active_reply": {
|
||||
"description": "主动回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用主动回复",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
|
||||
},
|
||||
"whitelist": {
|
||||
"description": "主动回复白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
|
||||
},
|
||||
"method": {
|
||||
"description": "回复方法",
|
||||
"type": "string",
|
||||
"options": ["possibility_reply"],
|
||||
"hint": "回复方法。possibility_reply 为根据概率回复",
|
||||
},
|
||||
"possibility_reply": {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"obvious_hint": True,
|
||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"misc_config_group": {
|
||||
@@ -413,18 +931,24 @@ CONFIG_METADATA_2 = {
|
||||
"description": "机器人唤醒前缀",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
|
||||
"obvious_hint": True,
|
||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
||||
},
|
||||
"t2i": {
|
||||
"description": "文本转图像",
|
||||
"type": "bool",
|
||||
"hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。",
|
||||
},
|
||||
"t2i_word_threshold": {
|
||||
"description": "文本转图像字数阈值",
|
||||
"type": "int",
|
||||
"hint": "超出此字符长度的文本将会被转换成图片。字数不能低于 50。",
|
||||
},
|
||||
"admins_id": {
|
||||
"description": "管理员 ID",
|
||||
"type": "list",
|
||||
"items": {"type": "int"},
|
||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
|
||||
"items": {"type": "string"},
|
||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
|
||||
},
|
||||
"http_proxy": {
|
||||
"description": "HTTP 代理",
|
||||
@@ -450,7 +974,8 @@ CONFIG_METADATA_2 = {
|
||||
"plugin_repo_mirror": {
|
||||
"description": "插件仓库镜像",
|
||||
"type": "string",
|
||||
"hint": "插件仓库的镜像地址,用于加速插件的下载。",
|
||||
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
|
||||
"obvious_hint": True,
|
||||
"options": [
|
||||
"default",
|
||||
"https://ghp.ci/",
|
||||
|
||||
119
astrbot/core/conversation_mgr.py
Normal file
119
astrbot/core/conversation_mgr.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
class ConversationManager():
|
||||
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
asyncio.create_task(self._periodic_save())
|
||||
|
||||
async def _periodic_save(self):
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
'''新建对话,并将当前会话的对话转移到新对话'''
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
'''切换会话的对话'''
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
|
||||
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id
|
||||
)
|
||||
del self.session_conversations[unified_msg_origin]
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
'''获取会话当前的对话 ID'''
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
|
||||
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
|
||||
'''获取会话的对话'''
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
'''获取会话的所有对话'''
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
|
||||
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
|
||||
'''更新会话的对话'''
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
history=json.dumps(history)
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
'''更新会话的对话标题'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
title=title
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
|
||||
'''更新会话的对话 Persona ID'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
persona_id=persona_id
|
||||
)
|
||||
|
||||
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
|
||||
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||
history = json.loads(conversation.history)
|
||||
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in history:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
@@ -1,11 +1,12 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
import time
|
||||
import threading
|
||||
import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -16,20 +17,24 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
self.astrbot_config = AstrBotConfig()
|
||||
self.astrbot_config = astrbot_config
|
||||
self.db = db
|
||||
|
||||
if self.astrbot_config['http_proxy']:
|
||||
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['no_proxy'] = 'localhost,127.0.0.1'
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("AstrBot v"+ VERSION)
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
self.event_queue = Queue()
|
||||
self.event_queue.closed = False
|
||||
|
||||
@@ -37,12 +42,22 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
|
||||
self.star_context.platform_manager = self.platform_manager
|
||||
self.star_context.provider_manager = self.provider_manager
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager
|
||||
)
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
self.plugin_manager.reload()
|
||||
await self.plugin_manager.reload()
|
||||
'''扫描、注册插件、实例化插件类'''
|
||||
|
||||
await self.provider_manager.initialize()
|
||||
@@ -69,12 +84,30 @@ class AstrBotCoreLifecycle:
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
|
||||
self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
|
||||
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
for task in tasks_:
|
||||
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
|
||||
|
||||
self.start_time = int(time.time())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
async def start(self):
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def stop(self):
|
||||
@@ -82,6 +115,8 @@ class AstrBotCoreLifecycle:
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
|
||||
for task in self.curr_tasks:
|
||||
try:
|
||||
await task
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
@@ -77,3 +77,37 @@ class BaseDatabase(abc.ABC):
|
||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
||||
'''通过 url 或 path 获取 ATRI 视觉数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
'''通过 user_id 和 cid 获取 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
'''删除 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
'''更新 Conversation 标题'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
'''更新 Conversation Persona ID'''
|
||||
raise NotImplementedError
|
||||
@@ -33,16 +33,16 @@ class Stats():
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
|
||||
@dataclass
|
||||
class LLMHistory():
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
|
||||
@dataclass
|
||||
class ATRIVision():
|
||||
'''Deprecated'''
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
@@ -52,3 +52,19 @@ class ATRIVision():
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
|
||||
@dataclass
|
||||
class Conversation():
|
||||
'''LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
'''
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
'''字符串格式的列表。'''
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
|
||||
Platform,
|
||||
Stats,
|
||||
LLMHistory,
|
||||
ATRIVision
|
||||
ATRIVision,
|
||||
Conversation
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
@@ -25,6 +26,37 @@ class SQLiteDatabase(BaseDatabase):
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
'''
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
'''
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
has_persona_id = False
|
||||
for row in res:
|
||||
if row[1] == "title":
|
||||
has_title = True
|
||||
if row[1] == "persona_id":
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
'''
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
'''
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
'''
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
'''
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
c.close()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.text_factory = str
|
||||
@@ -201,6 +233,91 @@ class SQLiteDatabase(BaseDatabase):
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
|
||||
if not res:
|
||||
return
|
||||
|
||||
return Conversation(*res)
|
||||
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
''', (user_id, cid, history, updated_at, created_at)
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
''', (user_id,)
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
for row in res:
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
|
||||
return conversations
|
||||
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新对话,并且同时更新时间'''
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, updated_at, user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
''', (title, user_id, cid)
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
''', (persona_id, user_id, cid)
|
||||
)
|
||||
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
ts = int(time.time())
|
||||
keywords = ",".join(vision.keywords)
|
||||
|
||||
@@ -36,3 +36,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT,
|
||||
cid TEXT,
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
@@ -54,6 +54,7 @@ class ComponentType(Enum):
|
||||
CardImage = "CardImage"
|
||||
TTS = "TTS"
|
||||
Unknown = "Unknown"
|
||||
File = "File"
|
||||
|
||||
|
||||
class BaseMessageComponent(BaseModel):
|
||||
@@ -305,7 +306,7 @@ class Image(BaseMessageComponent):
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
id: int
|
||||
id: T.Union[str, int]
|
||||
text: T.Optional[str] = ""
|
||||
qq: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0
|
||||
@@ -324,11 +325,13 @@ class RedBag(BaseMessageComponent):
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type: ComponentType = "Poke"
|
||||
qq: int
|
||||
type: str = ""
|
||||
id: T.Optional[int] = 0
|
||||
qq: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
def __init__(self, type: str, **_):
|
||||
type = f"Poke:{type}"
|
||||
super().__init__(type=type, **_)
|
||||
|
||||
|
||||
class Forward(BaseMessageComponent):
|
||||
@@ -338,14 +341,14 @@ class Forward(BaseMessageComponent):
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送
|
||||
class Node(BaseMessageComponent):
|
||||
'''群合并转发消息'''
|
||||
type: ComponentType = "Node"
|
||||
id: T.Optional[int] = 0
|
||||
name: T.Optional[str] = ""
|
||||
uin: T.Optional[int] = 0
|
||||
content: T.Optional[T.Union[str, list]] = ""
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[int] = 0 # qq号
|
||||
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, content: T.Union[str, list], **_):
|
||||
@@ -415,6 +418,17 @@ class Unknown(BaseMessageComponent):
|
||||
def toString(self):
|
||||
return ""
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
'''
|
||||
目前此消息段只适配了 Napcat。
|
||||
'''
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
|
||||
|
||||
ComponentTypes = {
|
||||
"plain": Plain,
|
||||
@@ -441,5 +455,6 @@ ComponentTypes = {
|
||||
"json": Json,
|
||||
"cardimage": CardImage,
|
||||
"tts": TTS,
|
||||
"unknown": Unknown
|
||||
"unknown": Unknown,
|
||||
'file': File,
|
||||
}
|
||||
|
||||
@@ -13,12 +13,10 @@ class MessageChain():
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
'''
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
def message(self, message: str):
|
||||
'''添加一条文本消息到消息链 `chain` 中。
|
||||
@@ -78,16 +76,6 @@ class MessageChain():
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def is_split(self, is_split: bool):
|
||||
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
Note:
|
||||
具体的效果以各适配器实现为准。
|
||||
|
||||
'''
|
||||
self.is_split_ = is_split
|
||||
return self
|
||||
|
||||
class EventResultType(enum.Enum):
|
||||
'''用于描述事件处理的结果类型。
|
||||
|
||||
@@ -98,6 +86,13 @@ class EventResultType(enum.Enum):
|
||||
CONTINUE = enum.auto()
|
||||
STOP = enum.auto()
|
||||
|
||||
class ResultContentType(enum.Enum):
|
||||
'''用于描述事件结果的内容的类型。
|
||||
'''
|
||||
LLM_RESULT = enum.auto()
|
||||
'''调用 LLM 产生的结果'''
|
||||
GENERAL_RESULT = enum.auto()
|
||||
'''普通的消息结果'''
|
||||
@dataclass
|
||||
class MessageEventResult(MessageChain):
|
||||
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
|
||||
@@ -106,12 +101,13 @@ class MessageEventResult(MessageChain):
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
`result_type` (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
|
||||
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
|
||||
|
||||
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
|
||||
|
||||
def stop_event(self) -> 'MessageEventResult':
|
||||
'''终止事件传播。
|
||||
'''
|
||||
@@ -130,5 +126,24 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
|
||||
'''设置事件处理的结果类型。
|
||||
|
||||
Args:
|
||||
result_type (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
self.result_content_type = typ
|
||||
return self
|
||||
|
||||
def is_llm_result(self) -> bool:
|
||||
'''是否为 LLM 结果。
|
||||
'''
|
||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
def get_plain_text(self) -> str:
|
||||
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
|
||||
'''
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -2,7 +2,9 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .respond.stage import RespondStage
|
||||
@@ -10,8 +12,9 @@ from .respond.stage import RespondStage
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"RateLimitCheckStage", # 检查会话是否超过频率限制
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage" # 发送消息
|
||||
@@ -20,7 +23,9 @@ STAGES_ORDER = [
|
||||
__all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
"RespondStage",
|
||||
|
||||
@@ -17,11 +17,14 @@ class ContentSafetyCheckStage(Stage):
|
||||
config = ctx.astrbot_config['content_safety']
|
||||
self.strategy_selector = StrategySelector(config)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''检查内容安全'''
|
||||
ok, info = self.strategy_selector.check(event.get_message_str())
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
ok, info = self.strategy_selector.check(text)
|
||||
if not ok:
|
||||
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
|
||||
if event.is_at_or_wake_command:
|
||||
event.set_result(MessageEventResult().message("你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"))
|
||||
yield
|
||||
event.stop_event()
|
||||
logger.info(f"内容安全检查不通过,原因:{info}")
|
||||
return
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from . import ContentSafetyStrategy
|
||||
from typing import List, Tuple
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
class StrategySelector:
|
||||
def __init__(self, config: dict) -> None:
|
||||
@@ -15,7 +15,8 @@ class StrategySelector:
|
||||
try:
|
||||
from .baidu_aip import BaiduAipStrategy
|
||||
except ImportError:
|
||||
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
|
||||
logger.warning("使用百度内容审核应该先 pip install baidu-aip")
|
||||
return
|
||||
self.enabled_strategies.append(
|
||||
BaiduAipStrategy(
|
||||
config["baidu_aip"]["app_id"],
|
||||
|
||||
70
astrbot/core/pipeline/preprocess_stage/stage.py
Normal file
70
astrbot/core/pipeline/preprocess_stage/stage.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Record, Image
|
||||
|
||||
@register_stage
|
||||
class PreProcessStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
|
||||
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
|
||||
self.platform_settings: dict = self.config.get('platform_settings', {})
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''在处理事件之前的预处理'''
|
||||
# 路径映射
|
||||
if mappings := self.platform_settings.get('path_mapping', []):
|
||||
# 支持 Record,Image 消息段的路径映射。
|
||||
message_chain = event.get_messages()
|
||||
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, (Record, Image)) and component.url:
|
||||
for mapping in mappings:
|
||||
from_, to_ = mapping.split(":")
|
||||
from_ = from_.removesuffix("/")
|
||||
to_ = to_.removesuffix("/")
|
||||
|
||||
url = component.url.removeprefix("file://")
|
||||
if url.startswith(from_):
|
||||
component.url = url.replace(from_, to_, 1)
|
||||
logger.debug(f"路径映射: {url} -> {component.url}")
|
||||
message_chain[idx] = component
|
||||
|
||||
# STT
|
||||
if self.stt_settings.get('enable', False):
|
||||
# TODO: 独立
|
||||
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
@@ -1,90 +1,159 @@
|
||||
'''
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
|
||||
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
|
||||
self.ctx = ctx
|
||||
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
|
||||
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# Chat 唤醒前缀
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
return
|
||||
event.message_str = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
|
||||
|
||||
if self.prompt_prefix:
|
||||
event.message_str = self.prompt_prefix + event.message_str
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
event.message_str = user_info + event.message_str
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
image_urls = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
image_urls.append(image_url)
|
||||
|
||||
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.provider_wake_prefix:
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix):]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
|
||||
# 获取对话上下文
|
||||
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
|
||||
if not conversation_id:
|
||||
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
|
||||
req.session_id = event.unified_msg_origin
|
||||
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
try:
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=event.message_str,
|
||||
session_id=event.session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=tools
|
||||
)
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
if _nested:
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
|
||||
# 执行 LLM 响应后的事件。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, llm_response)
|
||||
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text))
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
||||
elif llm_response.role == 'err':
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
|
||||
elif llm_response.role == 'tool':
|
||||
# function calling
|
||||
function_calling_result = {}
|
||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
||||
func_tool = tools.get_func(func_tool_name)
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
try:
|
||||
# 尝试调用工具函数
|
||||
|
||||
star_cls_obj = star_map.get(func_tool.module_name).star_cls
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
if hasattr(func_tool.func_obj, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ret = await func_tool.func_obj(event, **func_tool_args)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
|
||||
else:
|
||||
ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args)
|
||||
|
||||
if ret:
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.stop_event()
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
function_calling_result[func_tool_name] = resp
|
||||
else:
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
|
||||
if function_calling_result:
|
||||
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
|
||||
# 我们重新执行一遍这个 stage
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
|
||||
for tool_name, tool_result in function_calling_result.items():
|
||||
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
||||
req.prompt += extra_prompt
|
||||
async for _ in self.process(event, _nested=True):
|
||||
yield
|
||||
event.clear_result() # 清除上一个 func tool 的结果
|
||||
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
if llm_response.completion_text:
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text))
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
|
||||
return
|
||||
|
||||
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts
|
||||
new_record = {
|
||||
"role": "user",
|
||||
"content": req.prompt
|
||||
}
|
||||
contexts.append(new_record)
|
||||
contexts.append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=contexts_to_save
|
||||
)
|
||||
@@ -1,13 +1,16 @@
|
||||
'''
|
||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||
'''
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
import traceback
|
||||
import inspect
|
||||
|
||||
class StarRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
@@ -24,59 +27,23 @@ class StarRequestSubStage(Stage):
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_str not in star_map:
|
||||
if handler.handler_module_path not in star_map:
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
|
||||
|
||||
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
if hasattr(handler.handler, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ready_to_call = handler.handler(event, **params)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params)
|
||||
else:
|
||||
ready_to_call = handler.handler(star_cls_obj, event, **params)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for mer in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if mer:
|
||||
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(mer)
|
||||
yield
|
||||
else:
|
||||
if event.get_result():
|
||||
yield
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if ret:
|
||||
# 如果有返回值
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
if event.is_stopped() and event.get_result():
|
||||
# 插件主动停止事件传播,并且有结果
|
||||
event.continue_event()
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
yield
|
||||
elif not event.is_stopped and not event.get_result():
|
||||
continue
|
||||
else:
|
||||
yield
|
||||
logger.debug(f"执行插件 handler {handler.handler_full_name}")
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
event.stop_event()
|
||||
@@ -5,6 +5,8 @@ from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
@register_stage
|
||||
class ProcessStage(Stage):
|
||||
@@ -23,14 +25,35 @@ class ProcessStage(Stage):
|
||||
'''处理事件
|
||||
'''
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
|
||||
# 有插件 Handler 被激活
|
||||
if activated_handlers:
|
||||
async for _ in self.star_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
if not event._has_send_oper:
|
||||
'''当没有发送操作'''
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
async for resp in self.star_request_sub_stage.process(event):
|
||||
# 生成器返回值处理
|
||||
if isinstance(resp, ProviderRequest):
|
||||
# Handler 的 LLM 请求
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
# 调用 LLM 相关请求
|
||||
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
return
|
||||
|
||||
if not event._has_send_oper and event.is_at_or_wake_command:
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
# 事件没有终止传播
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
@@ -61,11 +61,12 @@ class RateLimitStage(Stage):
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD:
|
||||
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
|
||||
|
||||
@@ -1,20 +1,97 @@
|
||||
import random
|
||||
import asyncio
|
||||
import math
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core import logger
|
||||
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.message.components import Plain, Reply, At
|
||||
@register_stage
|
||||
class RespondStage:
|
||||
class RespondStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
|
||||
# 分段回复
|
||||
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
|
||||
|
||||
self.interval_method = ctx.astrbot_config['platform_settings']['segmented_reply']['interval_method']
|
||||
self.log_base = float(ctx.astrbot_config['platform_settings']['segmented_reply']['log_base'])
|
||||
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
|
||||
interval_str_ls = interval_str.replace(" ", "").split(",")
|
||||
try:
|
||||
self.interval = [float(t) for t in interval_str_ls]
|
||||
except BaseException as e:
|
||||
logger.error(f'解析分段回复的间隔时间失败。{e}')
|
||||
self.interval = [1.5, 3.5]
|
||||
logger.info(f"分段回复间隔时间:{self.interval}")
|
||||
|
||||
async def _word_cnt(self, text: str) -> int:
|
||||
'''分段回复 统计字数'''
|
||||
if all(ord(c) < 128 for c in text):
|
||||
word_count = len(text.split())
|
||||
else:
|
||||
word_count = len([c for c in text if c.isalnum()])
|
||||
return word_count
|
||||
|
||||
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
||||
'''分段回复 计算间隔时间'''
|
||||
if self.interval_method == 'log':
|
||||
if isinstance(comp, Plain):
|
||||
wc = await self._word_cnt(comp.text)
|
||||
i = math.log(wc + 1, self.log_base)
|
||||
return random.uniform(i, i + 0.5)
|
||||
else:
|
||||
return random.uniform(1, 1.75)
|
||||
else:
|
||||
# random
|
||||
return random.uniform(self.interval[0], self.interval[1])
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event.send(result)
|
||||
await event._pre_send()
|
||||
|
||||
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
|
||||
decorated_comps = []
|
||||
if self.reply_with_mention:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, At):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
if self.reply_with_quote:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Reply):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
else:
|
||||
await event.send(result)
|
||||
await event._post_send()
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
event.clear_result()
|
||||
@@ -1,38 +1,143 @@
|
||||
import time
|
||||
import re
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@register_stage
|
||||
class ResultDecorateStage:
|
||||
class ResultDecorateStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold']
|
||||
try:
|
||||
self.t2i_word_threshold = int(self.t2i_word_threshold)
|
||||
if self.t2i_word_threshold < 50:
|
||||
self.t2i_word_threshold = 50
|
||||
except BaseException:
|
||||
self.t2i_word_threshold = 150
|
||||
|
||||
self.forward_threshold = ctx.astrbot_config['platform_settings']['forward_threshold']
|
||||
|
||||
# 分段回复
|
||||
self.words_count_threshold = int(ctx.astrbot_config['platform_settings']['segmented_reply']['words_count_threshold'])
|
||||
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
|
||||
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
|
||||
self.content_cleanup_rule = ctx.astrbot_config['platform_settings']['segmented_reply']['content_cleanup_rule']
|
||||
|
||||
|
||||
# exception
|
||||
self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response']
|
||||
self.content_safe_check_stage = None
|
||||
if self.content_safe_check_reply:
|
||||
for stage in registered_stages:
|
||||
if stage.__class__.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
# 回复时检查内容安全
|
||||
if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result():
|
||||
text = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text += comp.text
|
||||
async for _ in self.content_safe_check_stage.process(event, check_text=text):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
result.chain.insert(0, Plain(self.reply_prefix))
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
comp.text = self.reply_prefix + comp.text
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
if self.enable_segmented_reply:
|
||||
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
if len(comp.text) > self.words_count_threshold:
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
split_response = []
|
||||
for line in comp.text.split("\n"):
|
||||
split_response.extend(re.findall(self.regex, line))
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
for seg in split_response:
|
||||
if self.content_cleanup_rule:
|
||||
seg = re.sub(self.content_cleanup_rule, "", seg)
|
||||
if seg.strip():
|
||||
new_chain.append(Plain(seg))
|
||||
else:
|
||||
# 非 Plain 类型的消息段不分段
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
if self.ctx.astrbot_config['provider_tts_settings']['enable'] and result.is_llm_result():
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(Record(file=audio_path, url=audio_path))
|
||||
else:
|
||||
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain_str += "\n\n" + comp.text
|
||||
else:
|
||||
break
|
||||
if plain_str and len(plain_str) > 150:
|
||||
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
||||
render_start = time.time()
|
||||
try:
|
||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
||||
@@ -43,3 +148,33 @@ class ResultDecorateStage:
|
||||
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
|
||||
if url:
|
||||
result.chain = [Image.fromURL(url)]
|
||||
|
||||
# 触发转发消息
|
||||
has_forwarded = False
|
||||
if event.get_platform_name() == 'aiocqhttp':
|
||||
word_cnt = 0
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
word_cnt += len(comp.text)
|
||||
if word_cnt > self.forward_threshold:
|
||||
node = Node(
|
||||
uin=event.get_self_id(),
|
||||
name="AstrBot",
|
||||
content=[
|
||||
*result.chain
|
||||
]
|
||||
)
|
||||
result.chain = [node]
|
||||
has_forwarded = True
|
||||
|
||||
if not has_forwarded:
|
||||
# at 回复
|
||||
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
|
||||
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
|
||||
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
|
||||
result.chain[1].text = "\n" + result.chain[1].text
|
||||
|
||||
# 引用回复
|
||||
if self.reply_with_quote:
|
||||
if not any(isinstance(item, File) for item in result.chain):
|
||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||
|
||||
@@ -12,14 +12,14 @@ class PipelineScheduler():
|
||||
|
||||
async def initialize(self):
|
||||
for stage in registered_stages:
|
||||
logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i]
|
||||
logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
coro = stage.process(event)
|
||||
if isinstance(coro, AsyncGenerator):
|
||||
async for _ in coro:
|
||||
@@ -41,4 +41,8 @@ class PipelineScheduler():
|
||||
async def execute(self, event: AstrMessageEvent):
|
||||
'''执行 pipeline'''
|
||||
await self._process_stages(event)
|
||||
|
||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import List, AsyncGenerator, Union
|
||||
import inspect
|
||||
from astrbot.api import logger
|
||||
from typing import List, AsyncGenerator, Union, Awaitable
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
registered_stages: List[Stage] = []
|
||||
'''维护了所有已注册的 Stage 实现类'''
|
||||
@@ -29,4 +32,37 @@ class Stage(abc.ABC):
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
async def _call_handler(
|
||||
self,
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
'''调用 Handler。'''
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as e:
|
||||
# 向下兼容
|
||||
logger.debug(str(e))
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for ret in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
@@ -2,11 +2,12 @@ from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
@@ -21,6 +22,9 @@ class WakingCheckStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"no_permission_reply", True
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -47,6 +51,7 @@ class WakingCheckStage(Stage):
|
||||
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
|
||||
break
|
||||
is_wake = True
|
||||
event.is_at_or_wake_command = True
|
||||
event.is_wake = True
|
||||
event.message_str = event.message_str[len(wake_prefix) :].strip()
|
||||
break
|
||||
@@ -60,54 +65,51 @@ class WakingCheckStage(Stage):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
wake_prefix = ""
|
||||
event.is_at_or_wake_command = True
|
||||
break
|
||||
# 检查是否是私聊
|
||||
if event.is_private_chat():
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
event.is_at_or_wake_command = True
|
||||
wake_prefix = ""
|
||||
|
||||
# 检查插件的 handler filter
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
for handler in star_handlers_registry:
|
||||
# filter 需要满足 AND 的逻辑关系
|
||||
passed = True
|
||||
child_command_handler_md = None
|
||||
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
if len(handler.event_filters) == 0:
|
||||
# 不可能有这种情况, 也不允许有这种情况
|
||||
continue
|
||||
|
||||
for filter in handler.event_filters:
|
||||
try:
|
||||
if isinstance(filter, CommandGroupFilter):
|
||||
"""如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata"""
|
||||
ok, child_command_handler_md = filter.filter(
|
||||
event, self.ctx.astrbot_config
|
||||
)
|
||||
if not ok:
|
||||
passed = False
|
||||
else:
|
||||
handler = child_command_handler_md # handler 覆盖
|
||||
break
|
||||
if isinstance(filter, PermissionTypeFilter):
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
permission_not_pass = True
|
||||
else:
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
passed = False
|
||||
break
|
||||
except Exception as e:
|
||||
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
|
||||
# yield
|
||||
await event.send(
|
||||
MessageEventResult().message(
|
||||
f"插件 {handler.handler_full_name} 报错:{e}"
|
||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||
)
|
||||
)
|
||||
event.stop_event()
|
||||
passed = False
|
||||
break
|
||||
|
||||
if passed:
|
||||
if permission_not_pass:
|
||||
if self.no_permission_reply:
|
||||
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
@@ -116,6 +118,7 @@ class WakingCheckStage(Stage):
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
|
||||
event.clear_extra()
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
|
||||
@@ -10,12 +10,25 @@ class WhitelistCheckStage(Stage):
|
||||
'''检查是否在群聊/私聊白名单
|
||||
'''
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
|
||||
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
|
||||
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
|
||||
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
|
||||
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if not self.enable_whitelist_check:
|
||||
# 白名单检查未启用
|
||||
return
|
||||
|
||||
if len(self.whitelist) == 0:
|
||||
# 白名单为空,不检查
|
||||
return
|
||||
|
||||
if event.get_platform_name() == 'webchat':
|
||||
# WebChat 豁免
|
||||
return
|
||||
|
||||
# 检查是否在白名单
|
||||
if self.wl_ignore_admin_on_group:
|
||||
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -7,6 +7,8 @@ from astrbot.core.platform.message_type import MessageType
|
||||
from typing import List, Union
|
||||
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
@@ -29,11 +31,19 @@ class AstrMessageEvent(abc.ABC):
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,):
|
||||
self.message_str = message_str
|
||||
'''纯文本的消息'''
|
||||
self.message_obj = message_obj
|
||||
'''消息对象,AstrBotMessage。带有完整的消息结构。'''
|
||||
self.platform_meta = platform_meta
|
||||
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
|
||||
self.session_id = session_id
|
||||
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
|
||||
self.role = "member"
|
||||
self.is_wake = False
|
||||
'''用户是否是管理员。如果是管理员,这里是 admin'''
|
||||
self.is_wake = False # 是否通过 WakingStage
|
||||
'''是否唤醒'''
|
||||
self.is_at_or_wake_command = False
|
||||
'''是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True,但是不会让这个属性置为 True)'''
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
@@ -41,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
session_id=session_id
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
|
||||
'''统一的消息来源字符串。格式为 platform_name:message_type:session_id'''
|
||||
self._result: MessageEventResult = None
|
||||
'''消息事件的结果'''
|
||||
|
||||
@@ -176,6 +186,15 @@ class AstrMessageEvent(abc.ABC):
|
||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
'''调度器会在执行 send() 前调用该方法'''
|
||||
pass
|
||||
|
||||
async def _post_send(self):
|
||||
'''调度器会在执行 send() 后调用该方法'''
|
||||
pass
|
||||
|
||||
|
||||
def set_result(self, result: Union[MessageEventResult, str]):
|
||||
'''设置消息事件的结果。
|
||||
|
||||
@@ -237,6 +256,8 @@ class AstrMessageEvent(abc.ABC):
|
||||
'''
|
||||
self._result = None
|
||||
|
||||
'''消息链相关'''
|
||||
|
||||
def make_result(self) -> MessageEventResult:
|
||||
'''
|
||||
创建一个空的消息事件结果。
|
||||
@@ -276,3 +297,39 @@ class AstrMessageEvent(abc.ABC):
|
||||
mer = MessageEventResult()
|
||||
mer.chain = chain
|
||||
return mer
|
||||
|
||||
'''LLM 请求相关'''
|
||||
|
||||
def request_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager = None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
contexts: List = [],
|
||||
system_prompt: str = "",
|
||||
conversation: Conversation = None
|
||||
) -> ProviderRequest:
|
||||
'''
|
||||
创建一个 LLM 请求。
|
||||
|
||||
Examples:
|
||||
```py
|
||||
yield event.request_llm(prompt="hi")
|
||||
```
|
||||
prompt: 提示词
|
||||
session_id: 已经过时,留空即可
|
||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。
|
||||
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
|
||||
'''
|
||||
return ProviderRequest(
|
||||
prompt = prompt,
|
||||
session_id = session_id,
|
||||
image_urls = image_urls,
|
||||
func_tool = func_tool_manager,
|
||||
contexts = contexts,
|
||||
system_prompt = system_prompt,
|
||||
conversation=conversation
|
||||
)
|
||||
@@ -4,7 +4,7 @@ from typing import List
|
||||
from asyncio import Queue
|
||||
from .register import platform_cls_map
|
||||
from astrbot.core import logger
|
||||
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
class PlatformManager():
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
@@ -15,16 +15,25 @@ class PlatformManager():
|
||||
self.settings = config['platform_settings']
|
||||
self.event_queue = event_queue
|
||||
|
||||
for platform in self.platforms_config:
|
||||
if not platform['enable']:
|
||||
continue
|
||||
match platform['type']:
|
||||
case "aiocqhttp":
|
||||
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
|
||||
case "qq_official":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
try:
|
||||
for platform in self.platforms_config:
|
||||
if not platform['enable']:
|
||||
continue
|
||||
match platform['type']:
|
||||
case "aiocqhttp":
|
||||
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
|
||||
case "qq_official":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "qq_official_webhook":
|
||||
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
|
||||
except Exception as e:
|
||||
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。")
|
||||
|
||||
async def initialize(self):
|
||||
for platform in self.platforms_config:
|
||||
@@ -34,9 +43,11 @@ class PlatformManager():
|
||||
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
cls_type = platform_cls_map[platform['type']]
|
||||
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
||||
logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
||||
inst = cls_type(platform, self.settings, self.event_queue)
|
||||
self.platform_insts.append(inst)
|
||||
|
||||
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
@@ -1,5 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
@dataclass
|
||||
class PlatformMetadata():
|
||||
name: str # 平台的名称
|
||||
description: str # 平台的描述
|
||||
name: str
|
||||
'''平台的名称'''
|
||||
description: str
|
||||
'''平台的描述'''
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
'''平台的默认配置模板'''
|
||||
adapter_display_name: str = None
|
||||
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''
|
||||
@@ -7,15 +7,34 @@ platform_registry: List[PlatformMetadata] = []
|
||||
platform_cls_map: Dict[str, Type] = {}
|
||||
'''维护了平台适配器名称和适配器类的映射'''
|
||||
|
||||
def register_platform_adapter(adapter_name: str, desc: str):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def register_platform_adapter(
|
||||
adapter_name: str,
|
||||
desc: str,
|
||||
default_config_tmpl: dict = None,
|
||||
adapter_display_name: str = None
|
||||
):
|
||||
'''用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
'''
|
||||
def decorator(cls):
|
||||
if adapter_name in platform_cls_map:
|
||||
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
|
||||
|
||||
# 添加必备选项
|
||||
if default_config_tmpl:
|
||||
if 'type' not in default_config_tmpl:
|
||||
default_config_tmpl['type'] = adapter_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
if 'id' not in default_config_tmpl:
|
||||
default_config_tmpl['id'] = adapter_name
|
||||
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
|
||||
@@ -20,27 +18,43 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d['type'] = 'text'
|
||||
if isinstance(segment, Image):
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
image_base64 = file_to_base64(segment.file[8:])
|
||||
bs64_data = file_to_base64(segment.file[8:])
|
||||
image_file_path = segment.file[8:]
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(segment.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
d['data']['file'] = image_base64
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
elif segment.file and segment.file.startswith("base64://"):
|
||||
bs64_data = segment.file
|
||||
else:
|
||||
bs64_data = file_to_base64(segment.file)
|
||||
d['data'] = {
|
||||
'file': bs64_data,
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d['data'] = {
|
||||
'qq': str(segment.qq) # 转换为字符串
|
||||
}
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
return
|
||||
|
||||
if message.is_split_: # 分条发送
|
||||
for m in ret:
|
||||
await self.bot.send(self.message_obj.raw_message, [m])
|
||||
await asyncio.sleep(random.uniform(0.75, 2.5))
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Music)):
|
||||
# 转发消息不能和普通消息混在一起发送
|
||||
send_one_by_one = True
|
||||
break
|
||||
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
await super().send(message)
|
||||
@@ -1,16 +1,20 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from .aiocqhttp_message_event import *
|
||||
from astrbot.api.message_components import *
|
||||
from .aiocqhttp_message_event import * # noqa: F403
|
||||
from astrbot.api.message_components import * # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
|
||||
class AiocqhttpAdapter(Platform):
|
||||
@@ -42,20 +46,83 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
logger.debug(f"[aiocqhttp] RawMessage {event}")
|
||||
|
||||
if event['post_type'] == 'message':
|
||||
abm = await self._convert_handle_message_event(event)
|
||||
elif event['post_type'] == 'notice':
|
||||
abm = await self._convert_handle_notice_event(event)
|
||||
elif event['post_type'] == 'request':
|
||||
abm = await self._convert_handle_request_event(event)
|
||||
|
||||
return abm
|
||||
|
||||
async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
|
||||
'''OneBot V11 请求类事件'''
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.tag = "aiocqhttp"
|
||||
abm.sender = MessageMember(
|
||||
user_id=event.user_id,
|
||||
nickname=event.user_id
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if 'group_id' in event and event['group_id']:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.sender.user_id + "_" + str(event.group_id)
|
||||
abm.message_str = ''
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_id = uuid.uuid4().hex
|
||||
abm.raw_message = event
|
||||
return abm
|
||||
|
||||
async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
|
||||
'''OneBot V11 通知类事件'''
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=event.user_id,
|
||||
nickname=event.user_id
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if 'group_id' in event and event['group_id']:
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.raw_message = event
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_id = uuid.uuid4().hex
|
||||
|
||||
if 'sub_type' in event:
|
||||
if event['sub_type'] == 'poke' and 'target_id' in event:
|
||||
abm.message.append(Poke(qq=str(event['target_id']), type='poke')) # noqa: F405
|
||||
|
||||
return abm
|
||||
|
||||
|
||||
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
|
||||
'''OneBot V11 消息类事件'''
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname'])
|
||||
|
||||
if event['message_type'] == 'group':
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
elif event['message_type'] == 'private':
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
if self.unique_session:
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
|
||||
@@ -72,32 +139,82 @@ class AiocqhttpAdapter(Platform):
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return
|
||||
logger.debug(f"aiocqhttp: 收到消息: {event.message}")
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for m in event.message:
|
||||
t = m['type']
|
||||
a = None
|
||||
if t == 'text':
|
||||
message_str += m['data']['text'].strip()
|
||||
a = ComponentTypes[t](**m['data'])
|
||||
elif t == 'file':
|
||||
if m['data']['url'] and m['data']['url'].startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
|
||||
file_name = m['data'].get('file_name', "file")
|
||||
path = os.path.join("data/temp", file_name)
|
||||
await download_file(m['data']['url'], path)
|
||||
|
||||
m['data'] = {
|
||||
"file": path,
|
||||
"name": file_name
|
||||
}
|
||||
|
||||
else:
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
|
||||
if not ret.get('file', None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret['file']):
|
||||
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
|
||||
|
||||
m['data'] = {
|
||||
"file": ret['file'],
|
||||
"name": ret['file_name']
|
||||
}
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
a = ComponentTypes[t](**m['data']) # noqa: F405
|
||||
abm.message.append(a)
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = event
|
||||
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
if not self.host or not self.port:
|
||||
return
|
||||
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199")
|
||||
self.host = "0.0.0.0"
|
||||
self.port = 6199
|
||||
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
|
||||
@self.bot.on_request()
|
||||
async def request(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_notice()
|
||||
async def notice(event: Event):
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
|
||||
445
astrbot/core/platform/sources/gewechat/client.py
Normal file
445
astrbot/core/platform/sources/gewechat/client.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import threading
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import quart
|
||||
import base64
|
||||
import datetime
|
||||
import re
|
||||
import os
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api import logger, sp
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
|
||||
class SimpleGewechatClient():
|
||||
'''针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
'''
|
||||
def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith('/'):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
|
||||
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
||||
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
self.userrealnames = {}
|
||||
|
||||
async def get_token_id(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob['data']
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {
|
||||
"X-GEWE-TOKEN": self.token
|
||||
}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
type_name = data['TypeName']
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
if 'Data' in data and 'CreateTime' in data['Data']:
|
||||
# 得到系统 UTF+8 的 ts
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
ts = datetime.datetime.now(tz).timestamp()
|
||||
create_time = data['Data']['CreateTime']
|
||||
if create_time < ts - 30:
|
||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||
return
|
||||
|
||||
|
||||
abm = AstrBotMessage()
|
||||
d = data['Data']
|
||||
|
||||
from_user_name = d['FromUserName']['string'] # 消息来源
|
||||
d['to_wxid'] = from_user_name # 用于发信息
|
||||
|
||||
abm.message_id = str(d.get('MsgId'))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data['Wxid'] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d['Content']['string'] # 消息内容
|
||||
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(':\n')
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
if '\u2005' in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r'@[^\u2005]*\u2005', '', content)
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d['MsgSource']
|
||||
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
|
||||
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
|
||||
at_me = True
|
||||
if '在群聊中@了你' in d.get('PushContent', ''):
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
if abm.group_id:
|
||||
if abm.group_id not in self.userrealnames or user_id not in self.userrealnames[abm.group_id]:
|
||||
# 获取群成员列表,并且缓存
|
||||
if abm.group_id not in self.userrealnames:
|
||||
self.userrealnames[abm.group_id] = {}
|
||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||
if member_list and 'memberList' in member_list:
|
||||
for member in member_list['memberList']:
|
||||
self.userrealnames[abm.group_id][member['wxid']] = member['nickName']
|
||||
if user_id in self.userrealnames[abm.group_id]:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0]
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
# 不同消息类型
|
||||
match d['MsgType']:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid,
|
||||
content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
# data = await self.multimedia_downloader.download_voice(
|
||||
# self.appid,
|
||||
# content,
|
||||
# abm.message_id
|
||||
# )
|
||||
# print(data)
|
||||
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
|
||||
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
case _:
|
||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||
abm.raw_message = d
|
||||
|
||||
logger.debug(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get('testMsg', None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = None
|
||||
try:
|
||||
abm = await self._convert(data)
|
||||
except BaseException as e:
|
||||
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。")
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"token": self.token,
|
||||
"callbackUrl": self.callback_url
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。")
|
||||
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger_placeholder
|
||||
)
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while not self.event_queue.closed:
|
||||
await asyncio.sleep(1)
|
||||
logger.info("gewechat 适配器已关闭。")
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
# /login/checkOnline
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"appId": appid
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob['data']
|
||||
|
||||
async def logout(self):
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"appId": self.appid
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
|
||||
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
|
||||
payload = {
|
||||
"appId": self.appid
|
||||
}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob['data']['qrData']
|
||||
qr_uuid = json_blob['data']['uuid']
|
||||
appid = json_blob['data']['appId']
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}")
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({
|
||||
"uuid": qr_uuid,
|
||||
"appId": appid
|
||||
})
|
||||
verify_flag = False
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
if verify_flag or os.path.exists("data/temp/gewe_code"):
|
||||
with open("data/temp/gewe_code", "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning("未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
payload['captchCode'] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
except:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
|
||||
ret = json_blob['ret']
|
||||
msg = ''
|
||||
if json_blob['data'] and 'msg' in json_blob['data']:
|
||||
msg = json_blob['data']['msg']
|
||||
if ret == 500 and '安全验证码' in msg:
|
||||
logger.warning("此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456")
|
||||
verify_flag = True
|
||||
else:
|
||||
status = json_blob['data']['status']
|
||||
nickname = json_blob['data'].get('nickName', '')
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
'''API'''
|
||||
|
||||
async def get_chatroom_member_list(self, chatroom_wxid: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": chatroom_wxid
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob['data']
|
||||
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
if ats:
|
||||
payload['ats'] = ats
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送语音结果: {json_blob}")
|
||||
|
||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"fileUrl": file_url,
|
||||
"fileName": file_name
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postFile",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送文件结果: {json_blob}")
|
||||
51
astrbot/core/platform/sources/gewechat/downloader.py
Normal file
51
astrbot/core/platform/sources/gewechat/downloader.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
class GeweDownloader():
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-GEWE-TOKEN": token
|
||||
}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {
|
||||
"appId": appid,
|
||||
"xml": xml,
|
||||
"msgId": msg_id
|
||||
}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
'''返回一个可下载的 URL'''
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {
|
||||
"appId": appid,
|
||||
"xml": xml,
|
||||
"type": choice
|
||||
}
|
||||
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
|
||||
json_blob = json.loads(data)
|
||||
if 'fileUrl' in json_blob['data']:
|
||||
return self.download_base_url + json_blob['data']['fileUrl']
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
139
astrbot/core/platform/sources/gewechat/gewechat_event.py
Normal file
139
astrbot/core/platform/sources/gewechat/gewechat_event.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, File
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, 'rb') as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
elif n_frames == 0:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(message: MessageChain, user_name: str):
|
||||
pass
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get('to_wxid', None)
|
||||
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
# 检查@
|
||||
ats = []
|
||||
ats_names = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, At):
|
||||
ats.append(comp.qq)
|
||||
ats_names.append(comp.name)
|
||||
has_at = False
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text = comp.text
|
||||
payload = {
|
||||
"to_wxid": to_wxid,
|
||||
"content": text,
|
||||
}
|
||||
if not has_at and ats:
|
||||
ats = f"{','.join(ats)}"
|
||||
ats_names = f"@{' @'.join(ats_names)}"
|
||||
text = f"{ats_names} {text}"
|
||||
payload["content"] = text
|
||||
payload["ats"] = ats
|
||||
has_at = True
|
||||
await self.client.post_text(**payload)
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_url = comp.file
|
||||
img_path = ""
|
||||
if img_url.startswith("file:///"):
|
||||
img_path = img_url[8:]
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
img_path = await download_image_by_url(comp.file)
|
||||
else:
|
||||
img_path = img_url
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
temp_directory = os.path.abspath('data/temp')
|
||||
img_path = os.path.abspath(img_path)
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await self.client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = ""
|
||||
|
||||
if record_url.startswith("file:///"):
|
||||
record_path = record_url[8:]
|
||||
elif record_url.startswith("http"):
|
||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
||||
else:
|
||||
record_path = record_url
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
await self.send(MessageChain().message(f"语音文件转换失败。{str(e)}"))
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await self.client.post_voice(to_wxid, record_url, duration*1000)
|
||||
elif isinstance(comp, File):
|
||||
file_path = comp.file
|
||||
file_name = comp.name
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
file_id = os.path.basename(file_path)
|
||||
file_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await self.client.post_file(to_wxid, file_url, file_id)
|
||||
elif isinstance(comp, At):
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||
|
||||
await super().send(message)
|
||||
@@ -0,0 +1,89 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.message.components import Plain
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client = None
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
to_wxid = session.session_id
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"gewechat",
|
||||
"基于 gewechat 的 Wechat 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config['base_url'],
|
||||
self.config['nickname'],
|
||||
self.config['host'],
|
||||
self.config['port'],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
return self._run()
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
|
||||
await self.client.start_polling()
|
||||
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss['unique_session']:
|
||||
message.session_id = message.sender.user_id + "_" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
175
astrbot/core/platform/sources/lark/lark_adapter.py
Normal file
175
astrbot/core/platform/sources/lark/lark_adapter.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import base64
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .lark_event import LarkMessageEvent
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot import logger
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
|
||||
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
|
||||
class LarkPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
|
||||
self.appid = platform_config['app_id']
|
||||
self.appsecret = platform_config['app_secret']
|
||||
self.domain = platform_config.get('domain', lark.FEISHU_DOMAIN)
|
||||
self.bot_name = platform_config.get('lark_bot_name', "astrbot")
|
||||
|
||||
if not self.bot_name:
|
||||
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
await self.convert_msg(event)
|
||||
|
||||
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
asyncio.create_task(on_msg_event_recv(event))
|
||||
|
||||
self.event_handler = lark.EventDispatcherHandler.builder("", "") \
|
||||
.register_p2_im_message_receive_v1(do_v2_msg_event) \
|
||||
.build()
|
||||
|
||||
self.client = lark.ws.Client(
|
||||
app_id=self.appid,
|
||||
app_secret=self.appsecret,
|
||||
log_level=lark.LogLevel.ERROR,
|
||||
domain=self.domain,
|
||||
event_handler=self.event_handler
|
||||
)
|
||||
|
||||
self.lark_api = (
|
||||
lark.Client.builder()
|
||||
.app_id(self.appid)
|
||||
.app_secret(self.appsecret)
|
||||
.build()
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"lark",
|
||||
"飞书机器人官方 API 适配器",
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
message = event.event.message
|
||||
abm = AstrBotMessage()
|
||||
abm.timestamp = int(message.create_time) / 1000
|
||||
abm.message = []
|
||||
abm.type = MessageType.GROUP_MESSAGE if message.chat_type == 'group' else MessageType.FRIEND_MESSAGE
|
||||
if message.chat_type == 'group':
|
||||
abm.group_id = message.chat_id
|
||||
abm.self_id = self.bot_name
|
||||
abm.message_str = ""
|
||||
|
||||
at_list = {}
|
||||
if message.mentions:
|
||||
for m in message.mentions:
|
||||
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
|
||||
if m.name == self.bot_name:
|
||||
abm.self_id = m.id.open_id
|
||||
|
||||
content_json_b = json.loads(message.content)
|
||||
|
||||
if message.message_type == 'text':
|
||||
message_str_raw = content_json_b['text'] # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
parts = re.split(at_pattern, message_str_raw)
|
||||
for i in range(len(parts)):
|
||||
s = parts[i].strip()
|
||||
if not s:
|
||||
continue
|
||||
if s in at_list:
|
||||
abm.message.append(at_list[s])
|
||||
else:
|
||||
abm.message.append(Plain(parts[i].strip()))
|
||||
elif message.message_type == 'post':
|
||||
_ls = []
|
||||
|
||||
content_ls = content_json_b.get('content', [])
|
||||
for comp in content_ls:
|
||||
if isinstance(comp, list):
|
||||
_ls.extend(comp)
|
||||
elif isinstance(comp, dict):
|
||||
_ls.append(comp)
|
||||
content_json_b = _ls
|
||||
elif message.message_type == 'image':
|
||||
content_json_b = [
|
||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}
|
||||
]
|
||||
|
||||
if message.message_type in ('post', 'image'):
|
||||
for comp in content_json_b:
|
||||
if comp['tag'] == 'at':
|
||||
abm.message.append(at_list[comp['user_id']])
|
||||
elif comp['tag'] == 'text' and comp['text'].strip():
|
||||
abm.message.append(Plain(comp['text'].strip()))
|
||||
elif comp['tag'] == 'img':
|
||||
image_key = comp['image_key']
|
||||
request = GetMessageResourceRequest.builder() \
|
||||
.message_id(message.message_id) \
|
||||
.file_key(image_key) \
|
||||
.type("image") \
|
||||
.build()
|
||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法下载飞书图片: {image_key}")
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
abm.message.append(Image.fromBase64(image_base64))
|
||||
|
||||
for comp in abm.message:
|
||||
if isinstance(comp, Plain):
|
||||
abm.message_str += comp.text
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
abm.sender = MessageMember(
|
||||
user_id=event.event.sender.sender_id.open_id,
|
||||
nickname=event.event.sender.sender_id.open_id[:8]
|
||||
)
|
||||
# 独立会话
|
||||
if not self.unique_session:
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
logger.debug(abm)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
event = LarkMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
bot=self.lark_api
|
||||
)
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
await self.client._connect()
|
||||
|
||||
96
astrbot/core/platform/sources/lark/lark_event.py
Normal file
96
astrbot/core/platform/sources/lark/lark_event.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
import uuid
|
||||
import lark_oapi as lark
|
||||
from typing import List
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, Record, At, Node, Music, Video
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from astrbot import logger
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id, bot: lark.Client):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List:
|
||||
ret = []
|
||||
_stage = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
_stage.append({
|
||||
"tag": "md",
|
||||
"text": comp.text
|
||||
})
|
||||
elif isinstance(comp, At):
|
||||
_stage.append({
|
||||
"tag": "at",
|
||||
"user_id": comp.qq,
|
||||
"style": []
|
||||
})
|
||||
elif isinstance(comp, AstrBotImage):
|
||||
file_path = ""
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace('file:///', '')
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
pass
|
||||
else:
|
||||
file_path = comp.file
|
||||
|
||||
request = CreateImageRequest.builder() \
|
||||
.request_body( \
|
||||
CreateImageRequestBody.builder() \
|
||||
.image_type("message") \
|
||||
.image(open(file_path, 'rb')) \
|
||||
.build() \
|
||||
) \
|
||||
.build()
|
||||
response = await lark_client.im.v1.image.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
image_key = response.data.image_key
|
||||
print(image_key)
|
||||
ret.append(_stage)
|
||||
ret.append([{
|
||||
"tag": "img",
|
||||
"image_key": image_key
|
||||
}])
|
||||
_stage.clear()
|
||||
else:
|
||||
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
|
||||
|
||||
if _stage:
|
||||
ret.append(_stage)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
}
|
||||
}
|
||||
|
||||
request = ReplyMessageRequest.builder() \
|
||||
.message_id(self.message_obj.message_id) \
|
||||
.request_body( \
|
||||
ReplyMessageRequestBody.builder() \
|
||||
.content(json.dumps(wrapped)) \
|
||||
.msg_type("post") \
|
||||
.uuid(str(uuid.uuid4())) \
|
||||
.reply_in_thread(False) \
|
||||
.build() \
|
||||
) \
|
||||
.build()
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send(message)
|
||||
@@ -8,18 +8,30 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
self.send_buffer = None
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not self.send_buffer:
|
||||
self.send_buffer = message
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def _post_send(self):
|
||||
'''QQ 官方 API 仅支持回复一次'''
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
|
||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
|
||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||
|
||||
if not plain_text and not image_base64 and not image_path:
|
||||
return
|
||||
|
||||
payload = {
|
||||
'content': plain_text,
|
||||
@@ -31,11 +43,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
|
||||
payload['media'] = media
|
||||
payload['msg_type'] = 7
|
||||
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
|
||||
case botpy.message.C2CMessage:
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
|
||||
payload['media'] = media
|
||||
payload['msg_type'] = 7
|
||||
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
|
||||
case botpy.message.Message:
|
||||
if image_path:
|
||||
@@ -46,7 +60,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
payload['file_image'] = image_path
|
||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
await super().send(message)
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
|
||||
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
|
||||
payload = {
|
||||
@@ -73,9 +89,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
plain_text += i.text
|
||||
elif isinstance(i, Image) and not image_base64:
|
||||
if i.file and i.file.startswith("file:///"):
|
||||
image_base64 = file_to_base64(i.file[8:])
|
||||
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
|
||||
image_file_path = i.file[8:]
|
||||
elif i.file and i.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(i.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
|
||||
else:
|
||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
||||
image_file_path = i.file
|
||||
else:
|
||||
logger.debug(f"qq_official 忽略 {i.type}")
|
||||
return plain_text, image_base64, image_file_path
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import botpy
|
||||
import logging
|
||||
import time
|
||||
@@ -11,7 +13,7 @@ from botpy import Client
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .qqofficial_message_event import QQOfficialMessageEvent
|
||||
from ...register import register_platform_adapter
|
||||
@@ -28,25 +30,25 @@ class botClient(Client):
|
||||
|
||||
# 收到群消息
|
||||
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
async def on_at_message_create(self, message: botpy.message.Message):
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到 C2C 消息
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id
|
||||
self._commit(abm)
|
||||
|
||||
@@ -102,7 +104,8 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
"QQ 机器人官方 API 适配器",
|
||||
)
|
||||
|
||||
def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage],
|
||||
@staticmethod
|
||||
def _parse_from_qqofficial(message: Union[botpy.message.Message, botpy.message.GroupMessage],
|
||||
message_type: MessageType):
|
||||
abm = AstrBotMessage()
|
||||
abm.type = message_type
|
||||
@@ -112,6 +115,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
abm.tag = "qq_official"
|
||||
msg: List[BaseMessageComponent] = []
|
||||
|
||||
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
|
||||
if isinstance(message, botpy.message.GroupMessage):
|
||||
abm.sender = MessageMember(
|
||||
@@ -126,7 +130,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message_str = message.content.strip()
|
||||
abm.self_id = "unknown_selfid"
|
||||
|
||||
msg.append(At(qq="qq_official"))
|
||||
msg.append(Plain(abm.message_str))
|
||||
if message.attachments:
|
||||
for i in message.attachments:
|
||||
@@ -146,7 +150,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
plain_content = message.content.replace(
|
||||
"<@!"+str(abm.self_id)+">", "").strip()
|
||||
msg.append(Plain(plain_content))
|
||||
|
||||
if message.attachments:
|
||||
for i in message.attachments:
|
||||
if i.content_type.startswith("image"):
|
||||
@@ -161,11 +165,14 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
str(message.author.id),
|
||||
str(message.author.username)
|
||||
)
|
||||
msg.append(At(qq="qq_official"))
|
||||
msg.append(Plain(plain_content))
|
||||
|
||||
if isinstance(message, botpy.message.Message):
|
||||
abm.group_id = message.channel_id
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {message_type}")
|
||||
abm.self_id = "qq_official"
|
||||
return abm
|
||||
|
||||
def run(self):
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
import botpy
|
||||
import logging
|
||||
import asyncio
|
||||
import botpy.message
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
|
||||
from botpy import Client
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .qo_webhook_event import QQOfficialWebhookMessageEvent
|
||||
from ...register import register_platform_adapter
|
||||
from .qo_webhook_server import QQOfficialWebhook
|
||||
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||
|
||||
# remove logger handler
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
# QQ 机器人官方框架
|
||||
class botClient(Client):
|
||||
def set_platform(self, platform: 'QQOfficialWebhookPlatformAdapter'):
|
||||
self.platform = platform
|
||||
|
||||
# 收到群消息
|
||||
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
async def on_at_message_create(self, message: botpy.message.Message):
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到 C2C 消息
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
abm.session_id = abm.sender.user_id
|
||||
self._commit(abm)
|
||||
|
||||
def _commit(self, abm: AstrBotMessage):
|
||||
self.platform.commit_event(QQOfficialWebhookMessageEvent(
|
||||
abm.message_str,
|
||||
abm,
|
||||
self.platform.meta(),
|
||||
abm.session_id,
|
||||
self
|
||||
))
|
||||
|
||||
@register_platform_adapter("qq_official_webhook", "QQ 机器人官方 API 适配器(Webhook)")
|
||||
class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
self.appid = platform_config['appid']
|
||||
self.secret = platform_config['secret']
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
|
||||
intents = botpy.Intents(
|
||||
public_messages=True,
|
||||
public_guild_messages=True,
|
||||
direct_message=True
|
||||
)
|
||||
self.client = botClient(
|
||||
intents=intents, # 已经无用
|
||||
bot_log=False,
|
||||
timeout=20,
|
||||
)
|
||||
self.client.set_platform(self)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"qq_official_webhook",
|
||||
"QQ 机器人官方 API 适配器",
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self.webhook_helper = QQOfficialWebhook(
|
||||
self.config,
|
||||
self._event_queue,
|
||||
self.client
|
||||
)
|
||||
await self.webhook_helper.initialize()
|
||||
await self.webhook_helper.start_polling()
|
||||
@@ -0,0 +1,18 @@
|
||||
import botpy
|
||||
import botpy.message
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Reply
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
|
||||
|
||||
|
||||
class QQOfficialWebhookMessageEvent(QQOfficialMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id, bot)
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import aiohttp
|
||||
import quart
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import typing
|
||||
from botpy import BotAPI, BotHttp, Client, Token, BotWebSocket, ConnectionSession
|
||||
from astrbot.api import logger
|
||||
import traceback
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
# remove logger handler
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
class QQOfficialWebhook():
|
||||
def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client):
|
||||
self.appid = config['appid']
|
||||
self.secret = config['secret']
|
||||
self.port = config.get("port", 6196)
|
||||
|
||||
if isinstance(self.port, str):
|
||||
self.port = int(self.port)
|
||||
|
||||
self.http: BotHttp = BotHttp(timeout=300)
|
||||
self.api: BotAPI = BotAPI(http=self.http)
|
||||
self.token = Token(self.appid, self.secret)
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule('/astrbot-qo-webhook/callback', view_func=self.callback, methods=['POST'])
|
||||
self.client = botpy_client
|
||||
self.event_queue = event_queue
|
||||
|
||||
async def initialize(self):
|
||||
logger.info(f"正在登录到 QQ 官方机器人...")
|
||||
self.user = await self.http.login(self.token)
|
||||
logger.info(f"已登录 QQ 官方机器人账号: {self.user}")
|
||||
# 直接注入到 botpy 的 Client,移花接木!
|
||||
self.client.api = self.api
|
||||
self.client.http = self.http
|
||||
|
||||
async def bot_connect():
|
||||
pass
|
||||
|
||||
self._connection = ConnectionSession(
|
||||
max_async=1,
|
||||
connect=bot_connect,
|
||||
dispatch=self.client.ws_dispatch,
|
||||
loop=asyncio.get_event_loop(),
|
||||
api=self.api,
|
||||
)
|
||||
|
||||
async def repeat_seed(self, bot_secret: str, target_size: int = 32) -> bytes:
|
||||
seed = bot_secret
|
||||
while len(seed) < target_size:
|
||||
seed *= 2
|
||||
return seed[:target_size].encode('utf-8')
|
||||
|
||||
|
||||
async def webhook_validation(self, validation_payload: dict):
|
||||
seed = await self.repeat_seed(self.secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
msg = validation_payload.get("event_ts", "") + validation_payload.get("plain_token", "")
|
||||
# sign
|
||||
signature = private_key.sign(msg.encode()).hex()
|
||||
response = {
|
||||
"plain_token": validation_payload.get("plain_token"),
|
||||
"signature": signature
|
||||
}
|
||||
return response
|
||||
|
||||
async def callback(self):
|
||||
msg: dict = await quart.request.json
|
||||
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
|
||||
|
||||
event = msg.get("t")
|
||||
opcode = msg.get("op")
|
||||
data = msg.get("d")
|
||||
|
||||
if opcode == 13:
|
||||
# validation
|
||||
signed = await self.webhook_validation(data)
|
||||
print(signed)
|
||||
return signed
|
||||
|
||||
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
|
||||
event = msg["t"].lower()
|
||||
try:
|
||||
func = self._connection.parser[event]
|
||||
except KeyError:
|
||||
logger.error("_parser unknown event %s.", event)
|
||||
else:
|
||||
func(msg)
|
||||
|
||||
return {"opcode": 12}
|
||||
|
||||
async def start_polling(self):
|
||||
await self.server.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger_placeholder
|
||||
)
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while not self.event_queue.closed:
|
||||
await asyncio.sleep(1)
|
||||
logger.info("qq_official_webhook 适配器已关闭。")
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import random
|
||||
import asyncio
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from vchat import Core
|
||||
|
||||
class VChatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(client: Core, message: MessageChain, user_name: str):
|
||||
plain = ""
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
if message.is_split_:
|
||||
await client.send_msg(comp.text, user_name)
|
||||
else:
|
||||
plain += comp.text
|
||||
elif isinstance(comp, Image):
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
with open(file_path, "rb") as f:
|
||||
await client.send_image(user_name, fd=f)
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_path = await download_image_by_url(comp.file)
|
||||
with open(image_path, "rb") as f:
|
||||
await client.send_image(user_name, fd=f)
|
||||
else:
|
||||
logger.error(f"不支持的 vchat(微信适配器) 消息类型: {comp}")
|
||||
await asyncio.sleep(random.uniform(0.5, 1.5)) # 🤓
|
||||
|
||||
if plain:
|
||||
await client.send_msg(plain, user_name)
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
|
||||
await super().send(message)
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .vchat_message_event import VChatPlatformEvent
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
from vchat import Core
|
||||
from vchat import model
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器")
|
||||
class VChatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
from_username = session.session_id.split('$$')[0]
|
||||
await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"vchat",
|
||||
"基于 VChat 的 Wechat 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.client = Core()
|
||||
@self.client.msg_register(msg_types=model.ContentTypes.TEXT,
|
||||
contact_type=model.ContactTypes.CHATROOM | model.ContactTypes.USER)
|
||||
async def _(msg: model.Message):
|
||||
if isinstance(msg.content, model.UselessContent):
|
||||
return
|
||||
if msg.create_time < self.start_time:
|
||||
logger.debug(f"忽略旧消息: {msg}")
|
||||
return
|
||||
logger.debug(f"收到消息: {msg.todict()}")
|
||||
abmsg = self.convert_message(msg)
|
||||
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
|
||||
asyncio.create_task(self.handle_msg(abmsg))
|
||||
|
||||
# TODO: 对齐微信服务器时间
|
||||
self.start_time = int(time.time())
|
||||
return self._run()
|
||||
|
||||
|
||||
async def _run(self):
|
||||
await self.client.init()
|
||||
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
|
||||
await self.client.run()
|
||||
|
||||
def convert_message(self, msg: model.Message) -> AstrBotMessage:
|
||||
# credits: https://github.com/z2z63/astrbot_plugin_vchat/blob/master/main.py#L49
|
||||
assert isinstance(msg.content, model.TextContent)
|
||||
amsg = AstrBotMessage()
|
||||
amsg.message = [Plain(msg.content.content)]
|
||||
amsg.self_id = self.client_self_id
|
||||
if msg.content.is_at_me:
|
||||
amsg.message.insert(0, At(qq=amsg.self_id))
|
||||
|
||||
sender = msg.chatroom_sender or msg.from_
|
||||
amsg.sender = MessageMember(sender.username, sender.nickname)
|
||||
|
||||
if msg.content.is_at_me:
|
||||
amsg.message_str = msg.content.content.split("\u2005")[1].strip()
|
||||
else:
|
||||
amsg.message_str = msg.content.content
|
||||
amsg.message_id = msg.message_id
|
||||
if isinstance(msg.from_, model.User):
|
||||
amsg.type = MessageType.FRIEND_MESSAGE
|
||||
elif isinstance(msg.from_, model.Chatroom):
|
||||
amsg.type = MessageType.GROUP_MESSAGE
|
||||
amsg.group_id = msg.from_.username
|
||||
else:
|
||||
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
|
||||
|
||||
amsg.raw_message = msg
|
||||
|
||||
if self.settingss['unique_session']:
|
||||
session_id = msg.from_.username + "$$" + msg.to.username
|
||||
if msg.chatroom_sender is not None:
|
||||
session_id += '$$' + msg.chatroom_sender.username
|
||||
else:
|
||||
session_id = msg.from_.username
|
||||
|
||||
amsg.session_id = session_id
|
||||
return amsg
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = VChatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
logger.info(f"处理消息: {message_event}")
|
||||
|
||||
self.commit_event(message_event)
|
||||
112
astrbot/core/platform/sources/webchat/webchat_adapter.py
Normal file
112
astrbot/core/platform/sources/webchat/webchat_adapter.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import os
|
||||
from typing import Awaitable, Any
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
class QueueListener:
|
||||
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
|
||||
self.queue = queue
|
||||
self.callback = callback
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
await self.callback(data)
|
||||
|
||||
@register_platform_adapter("webchat", "webchat")
|
||||
class WebChatAdapter(Platform):
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"webchat",
|
||||
"webchat",
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
# abm.session_id = f"webchat!{username}!{cid}"
|
||||
plain = ""
|
||||
cid = session.session_id.split("!")[-1]
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain += comp.text
|
||||
web_chat_back_queue.put_nowait((plain, cid))
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
username, cid, payload = data
|
||||
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = "webchat"
|
||||
abm.tag = "webchat"
|
||||
abm.sender = MessageMember(username, username)
|
||||
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
abm.session_id = f"webchat!{username}!{cid}"
|
||||
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.message = []
|
||||
|
||||
if payload['message']:
|
||||
abm.message.append(Plain(payload['message']))
|
||||
if payload['image_url']:
|
||||
if isinstance(payload['image_url'], list):
|
||||
for img in payload['image_url']:
|
||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
|
||||
else:
|
||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
|
||||
if payload['audio_url']:
|
||||
if isinstance(payload['audio_url'], list):
|
||||
for audio in payload['audio_url']:
|
||||
path = os.path.join(self.imgs_dir, audio)
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
else:
|
||||
path = os.path.join(self.imgs_dir, payload['audio_url'])
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
|
||||
logger.debug(f"WebChatAdapter: {abm.message}")
|
||||
|
||||
message_str = payload['message']
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
async def callback(data: tuple):
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
bot = QueueListener(web_chat_queue, callback)
|
||||
return bot.run()
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
|
||||
message_event = WebChatMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
50
astrbot/core/platform/sources/webchat/webchat_event.py
Normal file
50
astrbot/core/platform/sources/webchat/webchat_event.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not message:
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
|
||||
cid = self.session_id.split("!")[-1]
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file.startswith("base64://"):
|
||||
base64_str = comp.file[9:]
|
||||
image_data = base64.b64decode(base64_str)
|
||||
with open(path, "wb") as f:
|
||||
f.write(image_data)
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
else:
|
||||
logger.debug(f"webchat 忽略: {comp.type}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -1,9 +1,10 @@
|
||||
from .provider import Provider, Personality
|
||||
from .provider import Provider, Personality, STTProvider
|
||||
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from .entites import ProviderMetaData
|
||||
|
||||
__all__ = [
|
||||
"Provider",
|
||||
"Personality",
|
||||
"ProviderMetaData",
|
||||
"STTProvider"
|
||||
]
|
||||
64
astrbot/core/provider/entites.py
Normal file
64
astrbot/core/provider/entites.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData():
|
||||
type: str
|
||||
'''提供商适配器名称,如 openai, ollama'''
|
||||
desc: str = ""
|
||||
'''提供商适配器描述.'''
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Type = None
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
'''平台的默认配置模板'''
|
||||
provider_display_name: str = None
|
||||
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest():
|
||||
prompt: str
|
||||
'''提示词'''
|
||||
session_id: str = ""
|
||||
'''会话 ID'''
|
||||
image_urls: List[str] = None
|
||||
'''图片 URL 列表'''
|
||||
func_tool: FuncCall = None
|
||||
'''工具'''
|
||||
contexts: List = None
|
||||
'''上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
'''
|
||||
system_prompt: str = ""
|
||||
'''系统提示词'''
|
||||
conversation: Conversation = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色, assistant, tool, err'''
|
||||
completion_text: str = ""
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
'''工具调用名称'''
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
@@ -1,25 +1,9 @@
|
||||
import json
|
||||
import textwrap
|
||||
from typing import Awaitable, Dict, List
|
||||
from typing import Dict, List, Awaitable
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
|
||||
class FuncNotFoundError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
@@ -29,8 +13,8 @@ class FuncTool:
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
func_obj: Awaitable
|
||||
module_name: str = None
|
||||
handler: Awaitable
|
||||
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
|
||||
active: bool = True
|
||||
'''是否激活'''
|
||||
@@ -56,8 +40,7 @@ class FuncCall:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
func_obj: Awaitable,
|
||||
module_name: str = None,
|
||||
handler: Awaitable,
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
@@ -80,8 +63,7 @@ class FuncCall:
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
func_obj=func_obj,
|
||||
module_name=module_name,
|
||||
handler=handler,
|
||||
)
|
||||
self.func_list.append(_func)
|
||||
|
||||
@@ -120,9 +102,58 @@ class FuncCall:
|
||||
)
|
||||
return _l
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
"""
|
||||
获得 Anthropic API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
# Convert internal format to Anthropic style
|
||||
tool = {
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": f.parameters.get("properties", {}),
|
||||
# Keep the required field from the original parameters if it exists
|
||||
"required": f.parameters.get("required", [])
|
||||
}
|
||||
}
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> Dict:
|
||||
declarations = {}
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
func_declaration = {
|
||||
"name": f.name,
|
||||
"description": f.description
|
||||
}
|
||||
|
||||
# 检查并添加非空的properties参数
|
||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||
if params.get("properties", {}):
|
||||
func_declaration["parameters"] = params
|
||||
|
||||
tools.append(func_declaration)
|
||||
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"name": f["name"],
|
||||
@@ -179,12 +210,19 @@ class FuncCall:
|
||||
# 调用函数
|
||||
tool_callable = None
|
||||
for func in self.func_list:
|
||||
if func["name"] == func_name:
|
||||
tool_callable = func["func_obj"]
|
||||
if func.name == func_name:
|
||||
tool_callable = func.star_handler_metadata.handler
|
||||
break
|
||||
if not tool_callable:
|
||||
raise FuncNotFoundError(f"Request function {func_name} not found.")
|
||||
raise Exception(f"Request function {func_name} not found.")
|
||||
ret = await tool_callable(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色'''
|
||||
completion_text: str = None
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = None
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = None
|
||||
'''工具调用名称'''
|
||||
@@ -1,52 +1,244 @@
|
||||
import traceback
|
||||
import uuid
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .entites import ProviderType
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from collections import defaultdict
|
||||
from .register import provider_cls_map, llm_tools
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp
|
||||
|
||||
class ProviderManager():
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
self.providers_config: List = config['provider']
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
|
||||
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
|
||||
self.persona_configs: list = config.get('persona', [])
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
for persona in self.persona_configs:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
bd_processed = []
|
||||
mid_processed = ""
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append({
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None # 不持久化到 db
|
||||
})
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
|
||||
mood_imitation_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in mood_imitation_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
mid_processed += '\n'
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed=mid_processed
|
||||
)
|
||||
if persona['name'] == self.default_persona_name:
|
||||
self.selected_default_persona = persona
|
||||
self.personas.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
if not self.selected_default_persona:
|
||||
self.selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed=""
|
||||
)
|
||||
self.personas.append(self.selected_default_persona)
|
||||
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
'''加载的 Speech To Text Provider 的实例'''
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
'''加载的 Text To Speech Provider 的实例'''
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
'''当前使用的 Provider 实例'''
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
'''当前使用的 Speech To Text Provider 实例'''
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
'''当前使用的 Text To Speech Provider 实例'''
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
self.curr_kdb_name = ""
|
||||
kdb_cfg = config.get("knowledge_db", {})
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
changed = False
|
||||
for provider_cfg in self.providers_config:
|
||||
if not provider_cfg['enable']:
|
||||
continue
|
||||
|
||||
if provider_cfg['id'] in self.loaded_ids:
|
||||
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
|
||||
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
|
||||
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}。")
|
||||
provider_cfg['id'] = new_id
|
||||
changed = True
|
||||
self.loaded_ids[provider_cfg['id']] = True
|
||||
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
try:
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "dashscope":
|
||||
from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
|
||||
continue
|
||||
|
||||
if changed:
|
||||
try:
|
||||
config.save_config()
|
||||
except Exception as e:
|
||||
logger.warning(f"保存配置文件失败:{e}")
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
for provider_config in self.providers_config:
|
||||
if not provider_config['enable']:
|
||||
continue
|
||||
if provider_config['type'] not in provider_cls_map:
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
cls_type = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
|
||||
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
self.provider_insts.append(inst)
|
||||
|
||||
if len(self.provider_insts) > 0:
|
||||
provider_metadata = provider_cls_map[provider_config['type']]
|
||||
logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
try:
|
||||
# 按任务实例化提供商
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.db_helper,
|
||||
self.provider_settings.get('persistant_history', True),
|
||||
self.selected_default_persona
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if selected_provider_id == provider_config['id'] and provider_enabled:
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
|
||||
|
||||
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
|
||||
if stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
|
||||
if tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
await provider_inst.terminate()
|
||||
@@ -5,12 +5,21 @@ from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core import logger
|
||||
from typing import TypedDict
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: List[str] = []
|
||||
mood_imitation_dialogs: List[str] = []
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: List[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta():
|
||||
@@ -19,38 +28,12 @@ class ProviderMeta():
|
||||
type: str
|
||||
|
||||
|
||||
class Provider(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
persistant_history: bool = True,
|
||||
db_helper: BaseDatabase = None
|
||||
) -> None:
|
||||
class AbstractProvider(abc.ABC):
|
||||
def __init__(self, provider_config: dict) -> None:
|
||||
super().__init__()
|
||||
self.model_name = ""
|
||||
'''当前使用的模型名称'''
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_config = provider_config
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality = Personality(prompt=provider_settings['default_personality'])
|
||||
'''维护了当前的使用的 persona,即人格。'''
|
||||
|
||||
self.db_helper = db_helper
|
||||
'''用于持久化的数据库操作对象。'''
|
||||
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
try:
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
@@ -59,6 +42,31 @@ class Provider(abc.ABC):
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
persistant_history: bool = True,
|
||||
db_helper: BaseDatabase = None,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
'''维护了当前的使用的 persona,即人格。可能为 None'''
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
raise NotImplementedError()
|
||||
@@ -76,22 +84,6 @@ class Provider(abc.ABC):
|
||||
'''获得支持的模型列表'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
|
||||
'''获取人类可读的上下文
|
||||
|
||||
page 从 1 开始
|
||||
|
||||
Example:
|
||||
|
||||
["User: 你好", "Assistant: 你好!"]
|
||||
|
||||
Return:
|
||||
contexts: List[str]: 上下文列表
|
||||
total_pages: int: 总页数
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
@@ -99,37 +91,62 @@ class Provider(abc.ABC):
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts: List=None,
|
||||
system_prompt: str=None,
|
||||
**kwargs) -> LLMResponse:
|
||||
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
session_id: 会话 ID
|
||||
session_id: 会话 ID(此属性已经被废弃)
|
||||
image_urls: 图片 URL 列表
|
||||
tools: Function-calling 工具
|
||||
contexts: 上下文
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
|
||||
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
- 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。
|
||||
传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
|
||||
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
async def pop_record(self, context: List):
|
||||
'''
|
||||
弹出 context 第一条非系统提示词对话记录
|
||||
'''
|
||||
poped = 0
|
||||
indexs_to_pop = []
|
||||
for idx, record in enumerate(context):
|
||||
if record["role"] == "system":
|
||||
continue
|
||||
else:
|
||||
indexs_to_pop.append(idx)
|
||||
poped += 1
|
||||
if poped == 2:
|
||||
break
|
||||
|
||||
for idx in reversed(indexs_to_pop):
|
||||
context.pop(idx)
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''获取音频的文本'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_audio(self, text: str) -> str:
|
||||
'''获取文本的音频,返回音频文件路径'''
|
||||
raise NotImplementedError()
|
||||
@@ -1,6 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData():
|
||||
type: str # 提供商适配器名称,如 openai, ollama
|
||||
desc: str = "" # 提供商适配器描述.
|
||||
@@ -1,69 +1,47 @@
|
||||
import docstring_parser
|
||||
from typing import List, Dict, Type, Awaitable
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from typing import List, Dict, Type
|
||||
from .entites import ProviderMetaData, ProviderType
|
||||
from astrbot.core import logger
|
||||
from .tool import FuncCall, SUPPORTED_TYPES
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
provider_registry: List[ProviderMetaData] = []
|
||||
'''维护了通过装饰器注册的 Provider'''
|
||||
provider_cls_map: Dict[str, Type] = {}
|
||||
'''维护了 Provider 类型名称和 Provider 类的映射'''
|
||||
provider_cls_map: Dict[str, ProviderMetaData] = {}
|
||||
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
|
||||
|
||||
llm_tools = FuncCall()
|
||||
|
||||
def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
def register_provider_adapter(
|
||||
provider_type_name: str,
|
||||
desc: str,
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
|
||||
default_config_tmpl: dict = None,
|
||||
provider_display_name: str = None
|
||||
):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def decorator(cls):
|
||||
if provider_type_name in provider_cls_map:
|
||||
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
|
||||
|
||||
# 添加必备选项
|
||||
if default_config_tmpl:
|
||||
if 'type' not in default_config_tmpl:
|
||||
default_config_tmpl['type'] = provider_type_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
if 'id' not in default_config_tmpl:
|
||||
default_config_tmpl['id'] = provider_type_name
|
||||
|
||||
pm = ProviderMetaData(
|
||||
type=provider_type_name,
|
||||
desc=desc,
|
||||
provider_type=provider_type,
|
||||
cls_type=cls,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
provider_display_name=provider_display_name
|
||||
)
|
||||
provider_registry.append(pm)
|
||||
provider_cls_map[provider_type_name] = cls
|
||||
logger.debug(f"Provider {provider_type_name} 已注册")
|
||||
provider_cls_map[provider_type_name] = pm
|
||||
logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
|
||||
```
|
||||
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
|
||||
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
|
||||
\'\'\'获取天气信息。
|
||||
|
||||
Args:
|
||||
location(string): 地点
|
||||
\'\'\'
|
||||
# 处理逻辑
|
||||
```
|
||||
|
||||
可接受的参数类型有:string, number, object, array, boolean。
|
||||
'''
|
||||
name_ = name
|
||||
|
||||
def decorator(func_obj: Awaitable):
|
||||
llm_tool_name = name_ if name_ else func_obj.__name__
|
||||
module_name = func_obj.__module__
|
||||
docstring = docstring_parser.parse(func_obj.__doc__)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
|
||||
args.append({
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description
|
||||
})
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return func_obj
|
||||
|
||||
return decorator
|
||||
189
astrbot/core/provider/sources/anthropic_source.py
Normal file
189
astrbot/core/provider/sources/anthropic_source.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from typing import List
|
||||
from mimetypes import guess_type
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
@register_provider_adapter("anthropic_chat_completion", "Anthropic Claude API 提供商适配器")
|
||||
class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
# Skip OpenAI's __init__ and call Provider's __init__ directly
|
||||
Provider.__init__(self, provider_config, provider_settings, persistant_history, db_helper, default_persona)
|
||||
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
self.client = AsyncAnthropic(
|
||||
api_key=self.chosen_api_key,
|
||||
timeout=self.timeout,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_anthropic_style()
|
||||
if tool_list:
|
||||
payloads['tools'] = tool_list
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(completion, Message)
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
if len(completion.content) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
|
||||
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
|
||||
content = completion.content[-1]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if content.type == "text":
|
||||
# text completion
|
||||
completion_text = str(content.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
# Anthropic每次只返回一个函数调用
|
||||
if completion.stop_reason == "tool_use":
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
func_name_ls.append(content.name)
|
||||
args_ls.append(content.input)
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
|
||||
if not prompt:
|
||||
prompt = "<image>"
|
||||
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_config
|
||||
}
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
payloads['system'] = system_prompt
|
||||
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
response = await self.client.messages.create(
|
||||
messages=context_query,
|
||||
**model_config
|
||||
)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.completion_text = response.content[0].text
|
||||
llm_response.raw_completion = response
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''组装上下文,支持文本和图片'''
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
|
||||
content.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": image_data.split("base64,")[1] if "base64," in image_data else image_data
|
||||
}
|
||||
})
|
||||
|
||||
return {"role": "user", "content": content}
|
||||
128
astrbot/core/provider/sources/dashscope_source.py
Normal file
128
astrbot/core/provider/sources/dashscope_source.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core import logger, sp
|
||||
from dashscope import Application
|
||||
|
||||
|
||||
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||
class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality = None,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("阿里云百炼 API Key 不能为空。")
|
||||
self.app_id = provider_config.get("dashscope_app_id", "")
|
||||
if not self.app_id:
|
||||
raise Exception("阿里云百炼 APP ID 不能为空。")
|
||||
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
||||
if not self.dashscope_app_type:
|
||||
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||
self.model_name = "dashscope"
|
||||
self.variables: dict = provider_config.get("variables", {})
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
|
||||
# 支持多轮对话的
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
if image_urls:
|
||||
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
||||
contexts_no_img = await self._remove_image_from_context(contexts)
|
||||
context_query = [*contexts_no_img, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
# 调用阿里云百炼 API
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
app_id=self.app_id,
|
||||
api_key=self.api_key,
|
||||
messages=context_query,
|
||||
biz_params=payload_vars or None,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
else:
|
||||
# 不支持多轮对话的
|
||||
# 调用阿里云百炼 API
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
app_id=self.app_id,
|
||||
promtp=prompt,
|
||||
api_key=self.api_key,
|
||||
biz_params=payload_vars or None,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
logger.debug(f"dashscope resp: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
|
||||
)
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "")
|
||||
return LLMResponse(role="assistant", completion_text=output_text)
|
||||
|
||||
async def forget(self, session_id):
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
pass
|
||||
152
astrbot/core/provider/sources/dify_source.py
Normal file
152
astrbot/core/provider/sources/dify_source.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import logger, sp
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper, default_persona
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Dify API Key 不能为空。")
|
||||
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_client = DifyAPIClient(self.api_key, api_base)
|
||||
self.api_type = provider_config.get("dify_api_type", "")
|
||||
if not self.api_type:
|
||||
raise Exception("Dify API 类型不能为空。")
|
||||
self.model_name = "dify"
|
||||
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
|
||||
self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query")
|
||||
self.variables: dict = provider_config.get("variables", {})
|
||||
if not self.dify_query_input_key:
|
||||
self.dify_query_input_key = "astrbot_text_query"
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.conversation_ids = {}
|
||||
'''记录当前 session id 的对话 ID'''
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
result = ""
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
file_response = await self.api_client.file_upload(image_path, user=session_id)
|
||||
if 'id' not in file_response:
|
||||
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
|
||||
continue
|
||||
files_payload.append({
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response['id'],
|
||||
})
|
||||
else:
|
||||
# TODO: 处理更多情况
|
||||
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={
|
||||
**payload_vars,
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk['event'] == "message" or \
|
||||
chunk['event'] == "agent_message":
|
||||
result += chunk['answer']
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk['conversation_id']
|
||||
conversation_id = chunk['conversation_id']
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
self.dify_query_input_key: prompt,
|
||||
"astrbot_session_id": session_id,
|
||||
**payload_vars,
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
match chunk['event']:
|
||||
case "workflow_started":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
|
||||
case "node_finished":
|
||||
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
|
||||
case "workflow_finished":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
|
||||
if chunk['data']['error']:
|
||||
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
if self.workflow_output_key not in chunk['data']['outputs']:
|
||||
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
|
||||
result = chunk['data']['outputs'][self.workflow_output_key]
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Dify 请求失败:{str(e)}")
|
||||
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}")
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=result)
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.conversation_ids[session_id] = ""
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("Dify 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 Dify 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
await self.api_client.close()
|
||||
105
astrbot/core/provider/sources/fishaudio_tts_api_source.py
Normal file
105
astrbot/core/provider/sources/fishaudio_tts_api_source.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import uuid
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
from httpx import AsyncClient
|
||||
from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
audio: bytes
|
||||
text: str
|
||||
|
||||
|
||||
class ServeTTSRequest(BaseModel):
|
||||
text: str
|
||||
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
||||
# 音频格式
|
||||
format: Literal["wav", "pcm", "mp3"] = "mp3"
|
||||
mp3_bitrate: Literal[64, 128, 192] = 128
|
||||
# 参考音频
|
||||
references: list[ServeReferenceAudio] = []
|
||||
# 参考模型 ID
|
||||
# 例如 https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
||||
# 其中reference_id为 7f92f8afb8ec43bf81429cc1c9199cb1
|
||||
reference_id: str | None = None
|
||||
# 对中英文文本进行标准化,这可以提高数字的稳定性
|
||||
normalize: bool = True
|
||||
# 平衡模式将延迟减少到300毫秒,但可能会降低稳定性
|
||||
latency: Literal["normal", "balanced"] = "normal"
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"fishaudio_tts_api", "FishAudio TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.character: str = provider_config.get("fishaudio-tts-character", "可莉")
|
||||
self.api_base: str = provider_config.get(
|
||||
"api_base", "https://api.fish-audio.cn/v1"
|
||||
)
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str:
|
||||
"""
|
||||
获取角色的reference_id
|
||||
|
||||
Args:
|
||||
character: 角色名称
|
||||
|
||||
Returns:
|
||||
reference_id: 角色的reference_id
|
||||
|
||||
exception:
|
||||
APIException: 获取语音角色列表为空
|
||||
"""
|
||||
sort_options = ["score", "task_count", "created_at"]
|
||||
async with AsyncClient(base_url=self.api_base.replace("/v1", "")) as client:
|
||||
for sort_by in sort_options:
|
||||
params = {"title": character, "sort_by": sort_by}
|
||||
response = await client.get(
|
||||
"/model", params=params, headers=self.headers
|
||||
)
|
||||
resp_data = response.json()
|
||||
if resp_data["total"] == 0:
|
||||
continue
|
||||
for item in resp_data["items"]:
|
||||
if character in item["title"]:
|
||||
return item["_id"]
|
||||
return None
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
return ServeTTSRequest(
|
||||
text=text,
|
||||
format="wav",
|
||||
reference_id=await self._get_reference_id_by_character(self.character),
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
) as response:
|
||||
if response.headers["content-type"] == "audio/wav":
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
text = await response.aread()
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
284
astrbot/core/provider/sources/gemini_source.py
Normal file
284
astrbot/core/provider/sources/gemini_source.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import base64
|
||||
import aiohttp
|
||||
import random
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
class SimpleGoogleGenAIClient():
|
||||
def __init__(self, api_key: str, api_base: str, timeout: int=120) -> None:
|
||||
self.api_key = api_key
|
||||
if api_base.endswith("/"):
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
self.timeout = timeout
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
async with self.client.get(request_url, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
models = []
|
||||
for model in response["models"]:
|
||||
if 'generateContent' in model["supportedGenerationMethods"]:
|
||||
models.append(model["name"].replace("models/", ""))
|
||||
return models
|
||||
|
||||
async def generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str="gemini-1.5-flash",
|
||||
system_instruction: str="",
|
||||
tools: dict=None
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {
|
||||
"parts": {"text": system_instruction}
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp:
|
||||
if "application/json" in resp.headers.get("Content-Type"):
|
||||
try:
|
||||
response = await resp.json()
|
||||
except Exception as e:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise e
|
||||
return response
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise Exception("Gemini 返回了非 json 数据: ")
|
||||
|
||||
|
||||
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
|
||||
class ProviderGoogleGenAI(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True,
|
||||
default_persona: Personality=None
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 180)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.client = SimpleGoogleGenAIClient(
|
||||
api_key=self.chosen_api_key,
|
||||
api_base=provider_config.get("api_base", None),
|
||||
timeout=self.timeout
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
if tools:
|
||||
tool = tools.get_func_desc_google_genai_style()
|
||||
if not tool:
|
||||
tool = None
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "system":
|
||||
system_instruction = message["content"]
|
||||
break
|
||||
|
||||
google_genai_conversation = []
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
if not message['content']:
|
||||
message['content'] = "<empty_content>"
|
||||
|
||||
google_genai_conversation.append({
|
||||
"role": "user",
|
||||
"parts": [{"text": message["content"]}]
|
||||
})
|
||||
elif isinstance(message["content"], list):
|
||||
# images
|
||||
parts = []
|
||||
for part in message["content"]:
|
||||
if part["type"] == "text":
|
||||
if not part["text"]:
|
||||
part["text"] = "<empty_content>"
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
parts.append({"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64
|
||||
}})
|
||||
google_genai_conversation.append({
|
||||
"role": "user",
|
||||
"parts": parts
|
||||
})
|
||||
|
||||
elif message["role"] == "assistant":
|
||||
if not message["content"]:
|
||||
message["content"] = "<empty_content>"
|
||||
google_genai_conversation.append({
|
||||
"role": "model",
|
||||
"parts": [{"text": message["content"]}]
|
||||
})
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
if "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
|
||||
candidates = result["candidates"][0]['content']['parts']
|
||||
llm_response = LLMResponse("assistant")
|
||||
for candidate in candidates:
|
||||
if 'text' in candidate:
|
||||
llm_response.completion_text += candidate['text']
|
||||
elif 'functionCall' in candidate:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args.append(candidate['functionCall']['args'])
|
||||
llm_response.tools_call_name.append(candidate['functionCall']['name'])
|
||||
|
||||
llm_response.completion_text = llm_response.completion_text.strip()
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config['model'] = self.get_model()
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_config
|
||||
}
|
||||
llm_response = None
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(keys)
|
||||
|
||||
for i in range(retry):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
|
||||
elif "Function calling is not enabled" in str(e):
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
llm_response = await self._query(payloads, None)
|
||||
elif "429" in str(e) or "API key not valid" in str(e):
|
||||
keys.remove(chosen_key)
|
||||
if len(keys) > 0:
|
||||
chosen_key = random.choice(keys)
|
||||
logger.info(f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}...")
|
||||
continue
|
||||
else:
|
||||
logger.error(f"A检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...")
|
||||
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
|
||||
else:
|
||||
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
组装上下文。
|
||||
'''
|
||||
if image_urls:
|
||||
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user","content": text}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
'''
|
||||
将图片转换为 base64
|
||||
'''
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ''
|
||||
@@ -2,81 +2,95 @@ import json
|
||||
import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
|
||||
|
||||
@register_provider_adapter(
|
||||
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
|
||||
)
|
||||
class LLMTunerModelLoader(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
persistant_history=True,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
self.base_model_path = provider_config['base_model_path']
|
||||
self.adapter_model_path = provider_config['adapter_model_path']
|
||||
self.model = ChatModel({
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config['llmtuner_template'],
|
||||
"finetuning_type": provider_config['finetuning_type'],
|
||||
"quantization_bit": provider_config['quantization_bit'],
|
||||
})
|
||||
self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper, default_persona
|
||||
)
|
||||
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
|
||||
provider_config["adapter_model_path"]
|
||||
):
|
||||
raise FileNotFoundError("模型文件路径不存在。")
|
||||
self.base_model_path = provider_config["base_model_path"]
|
||||
self.adapter_model_path = provider_config["adapter_model_path"]
|
||||
self.model = ChatModel(
|
||||
{
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config["llmtuner_template"],
|
||||
"finetuning_type": provider_config["finetuning_type"],
|
||||
"quantization_bit": provider_config["quantization_bit"],
|
||||
}
|
||||
)
|
||||
self.set_model(
|
||||
os.path.basename(self.base_model_path)
|
||||
+ "_"
|
||||
+ os.path.basename(self.adapter_model_path)
|
||||
)
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
"""
|
||||
组装上下文。
|
||||
'''
|
||||
"""
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str] = None,
|
||||
tools = None,
|
||||
contexts: List=None,
|
||||
**kwargs) -> str:
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = [],
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
system_prompt = ""
|
||||
if not contexts:
|
||||
contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
else:
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(contexts):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + contexts.pop(idx)["content"]
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
logger.debug(f"请求上下文:{contexts}")
|
||||
logger.debug(f"请求 System Prompt:{system_prompt}")
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(query_context):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
|
||||
if '_no_save' in context:
|
||||
del context['_no_save']
|
||||
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + query_context.pop(idx)["content"]
|
||||
|
||||
conf = {
|
||||
"messages": contexts,
|
||||
"messages": query_context,
|
||||
"system": system_prompt,
|
||||
}
|
||||
if tools:
|
||||
conf['tools'] = tools
|
||||
if func_tool:
|
||||
tool_list = func_tool.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
conf['tools'] = tool_list
|
||||
|
||||
responses = await self.model.achat(**conf)
|
||||
logger.debug(f"返回上下文:{responses}")
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
|
||||
self.session_memory[session_id].append({"role": "user", "content": prompt})
|
||||
self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text})
|
||||
return responses[-1].response_text
|
||||
|
||||
async def forget(self, session_id):
|
||||
logger.info("llmtuner reset")
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
@@ -86,28 +100,3 @@ class LLMTunerModelLoader(Provider):
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
@@ -1,20 +1,19 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
import datetime
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import NotFoundError
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
@@ -23,79 +22,53 @@ class ProviderOpenAIOfficial(Provider):
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
persistant_history = True,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.enable_datetime = provider_config.get("datetime_system_prompt", True)
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
# 适配 azure openai #332
|
||||
if "api_version" in provider_config:
|
||||
# 使用 azure api
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=self.timeout
|
||||
)
|
||||
else:
|
||||
# 使用 openai api
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
contexts.append(f"Assistant: {record['content']}")
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
model_config = provider_config.get("model_config", {})
|
||||
model = model_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
models = models.data
|
||||
models = sorted(models.data, key=lambda x: x.id)
|
||||
for model in models:
|
||||
models_str.append(model.id)
|
||||
return models_str
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
弹出第一条记录
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
payloads["tools"] = tools.get_func_desc_openai_style()
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
payloads['tools'] = tool_list
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
@@ -103,17 +76,20 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"completion: {completion.usage}")
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
return LLMResponse("assistant", completion_text)
|
||||
elif choice.message.tool_calls:
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
if choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
@@ -123,65 +99,135 @@ class ProviderOpenAIOfficial(Provider):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
|
||||
else:
|
||||
raise Exception("Internal Error")
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
if choice.finish_reason == 'content_filter':
|
||||
raise Exception("API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。")
|
||||
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
system_prompt = ""
|
||||
if self.curr_personality["prompt"]:
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
if self.enable_datetime:
|
||||
system_prompt += f"Current datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
else:
|
||||
context_query = contexts
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
logger.debug(f"请求上下文:{context_query}, {self.get_model()}")
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
}
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
llm_response.raw_completion = completion
|
||||
|
||||
return llm_response
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str=None,
|
||||
image_urls: List[str]=[],
|
||||
func_tool: FuncCall=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config['model'] = self.get_model()
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_config
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads['messages'] = new_contexts
|
||||
context_query = new_contexts
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
# 重试 10 次
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads['messages'] = new_contexts
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
elif 'does not support Function Calling' in str(e) \
|
||||
or 'does not support tools' in str(e) \
|
||||
or 'Function call is not supported' in str(e) \
|
||||
or 'Function calling is not enabled' in str(e) \
|
||||
or 'Tool calling is not supported' in str(e) \
|
||||
or 'No endpoints found that support tool use' in str(e) \
|
||||
or 'model does not support function calling' in str(e) \
|
||||
or ('tool' in str(e) and 'support' in str(e).lower()) \
|
||||
or ('function' in str(e) and 'support' in str(e).lower()):
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
llm_response = await self._query(payloads, None)
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if 'tool' in str(e).lower() and 'support' in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
if 'Connection error.' in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}")
|
||||
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
'''
|
||||
从上下文中删除所有带有 image 的记录
|
||||
'''
|
||||
new_contexts = []
|
||||
|
||||
flag = False
|
||||
for context in contexts:
|
||||
if flag:
|
||||
flag = False # 删除 image 后,下一条(LLM 响应)也要删除
|
||||
continue
|
||||
if isinstance(context['content'], list):
|
||||
flag = True
|
||||
# continue
|
||||
new_content = []
|
||||
for item in context['content']:
|
||||
if isinstance(item, dict) and 'image_url' in item:
|
||||
continue
|
||||
new_content.append(item)
|
||||
if not new_content:
|
||||
# 用户只发了图片
|
||||
new_content = [{"type": "text", "text": "[图片]"}]
|
||||
context['content'] = new_content
|
||||
new_contexts.append(context)
|
||||
return new_contexts
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
@@ -202,8 +248,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
|
||||
40
astrbot/core/provider/sources/openai_tts_api_source.py
Normal file
40
astrbot/core/provider/sources/openai_tts_api_source.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import uuid
|
||||
import os
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
|
||||
class ProviderOpenAITTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
self.voice = provider_config.get("openai-tts-voice", "alloy")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name,
|
||||
voice=self.voice,
|
||||
response_format='wav',
|
||||
input=text
|
||||
) as response:
|
||||
with open(path, 'wb') as f:
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
74
astrbot/core/provider/sources/whisper_api_source.py
Normal file
74
astrbot/core/provider/sources/whisper_api_source.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import uuid
|
||||
import os
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
filename = str(uuid.uuid4()) + '.mp3'
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
|
||||
is_tencent = False
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=open(audio_url, "rb"),
|
||||
)
|
||||
return result.text
|
||||
72
astrbot/core/provider/sources/whisper_selfhosted_source.py
Normal file
72
astrbot/core/provider/sources/whisper_selfhosted_source.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import uuid
|
||||
import os
|
||||
import asyncio
|
||||
import whisper
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.model = None
|
||||
|
||||
async def initialize(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
|
||||
logger.info("Whisper 模型加载完成。")
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
filename = str(uuid.uuid4()) + '.mp3'
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
is_tencent = False
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result['text']
|
||||
78
astrbot/core/provider/sources/zhipu_source.py
Normal file
78
astrbot/core/provider/sources/zhipu_source.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import traceback
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
|
||||
class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True,
|
||||
default_persona = None
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, db_helper, persistant_history, default_persona)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
model = self.get_model()
|
||||
# glm-4v-flash 只支持一张图片
|
||||
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
|
||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||
logger.debug(context_query)
|
||||
new_context_query_ = []
|
||||
for i in range(0, len(context_query) - 1, 2):
|
||||
if isinstance(context_query[i].get("content", ""), list):
|
||||
continue
|
||||
new_context_query_.append(context_query[i])
|
||||
new_context_query_.append(context_query[i+1])
|
||||
new_context_query_.append(context_query[-1]) # 保留最后一条记录
|
||||
context_query = new_context_query_
|
||||
logger.debug(context_query)
|
||||
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_cfgs
|
||||
}
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
25
astrbot/core/rag/embedding/openai_source.py
Normal file
25
astrbot/core/rag/embedding/openai_source.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
class SimpleOpenAIEmbedding():
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
'''
|
||||
获取文本的嵌入
|
||||
'''
|
||||
embedding = await self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model
|
||||
)
|
||||
return embedding.data[0].embedding
|
||||
92
astrbot/core/rag/knowledge_db_mgr.py
Normal file
92
astrbot/core/rag/knowledge_db_mgr.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
class KnowledgeDBManager():
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
|
||||
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
'''
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
'''
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]['strategy']:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
8
astrbot/core/rag/store/__init__.py
Normal file
8
astrbot/core/rag/store/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
class Store():
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
39
astrbot/core/rag/store/chroma_db.py
Normal file
39
astrbot/core/rag/store/chroma_db.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None)
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding,
|
||||
n_results=top_n,
|
||||
where=metadata_filter
|
||||
)
|
||||
return results['documents'][0]
|
||||
@@ -1,3 +1,7 @@
|
||||
'''
|
||||
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
|
||||
'''
|
||||
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, TypedDict, Union
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from .star import star_registry, StarMetadata
|
||||
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
|
||||
class StarCommand(TypedDict):
|
||||
full_command_name: str
|
||||
command_name: str
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
|
||||
class Context:
|
||||
'''
|
||||
@@ -38,43 +37,39 @@ class Context:
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
def __init__(self,
|
||||
event_queue: Queue,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
'''根据插件名获取插件的 Metadata'''
|
||||
for star in star_registry:
|
||||
if star.name == star_name:
|
||||
return star
|
||||
|
||||
def get_all_stars(self) -> List[StarMetadata]:
|
||||
'''获取当前载入的所有插件 Metadata 的列表'''
|
||||
return star_registry
|
||||
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
'''
|
||||
获取 LLM Tool Manager
|
||||
'''
|
||||
'''获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools'''
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 异步处理函数。
|
||||
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
|
||||
|
||||
@@ -83,7 +78,18 @@ class Context:
|
||||
'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -95,81 +101,66 @@ class Context:
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
注册一个命令。
|
||||
|
||||
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
|
||||
|
||||
@param star_name: 插件(Star)名称。
|
||||
@param command_name: 命令名称。
|
||||
@param desc: 命令描述。
|
||||
@param priority: 优先级。1-10。
|
||||
@param awaitable: 异步处理函数。
|
||||
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_str=awaitable.__module__,
|
||||
handler=awaitable,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
if use_regex:
|
||||
md.event_filters.append(RegexFilter(
|
||||
regex=command_name
|
||||
))
|
||||
else:
|
||||
md.event_filters.append(CommandFilter(
|
||||
command_name=command_name,
|
||||
handler_md=md
|
||||
))
|
||||
star_handlers_registry.append(md)
|
||||
star_handlers_map[md.handler_full_name] = md
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider。
|
||||
注册一个 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''
|
||||
通过 ID 获取 LLM Provider。
|
||||
'''
|
||||
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''
|
||||
获取所有 LLM Provider。
|
||||
'''
|
||||
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_all_tts_providers(self) -> List[TTSProvider]:
|
||||
'''获取所有用于 TTS 任务的 Provider。'''
|
||||
return self.provider_manager.tts_provider_insts
|
||||
|
||||
def get_all_stt_providers(self) -> List[STTProvider]:
|
||||
'''获取所有用于 STT 任务的 Provider。'''
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的 LLM Provider。
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_using_tts_provider(self) -> TTSProvider:
|
||||
'''
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
'''
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
|
||||
def get_using_stt_provider(self) -> STTProvider:
|
||||
'''
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
'''
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''
|
||||
获取 AstrBot 配置信息。
|
||||
'''
|
||||
'''获取 AstrBot 的配置。'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''
|
||||
获取 AstrBot 数据库。
|
||||
'''
|
||||
'''获取 AstrBot 数据库。'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
@@ -196,12 +187,77 @@ class Context:
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.registered_platforms:
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
'''
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
'''
|
||||
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 异步处理函数。
|
||||
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_path=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
star_handlers_registry.append(md)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
注册一个命令。
|
||||
|
||||
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
|
||||
|
||||
@param star_name: 插件(Star)名称。
|
||||
@param command_name: 命令名称。
|
||||
@param desc: 命令描述。
|
||||
@param priority: 优先级。1-10。
|
||||
@param awaitable: 异步处理函数。
|
||||
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
event_type=EventType.AdapterMessageEvent,
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_path=awaitable.__module__,
|
||||
handler=awaitable,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
if use_regex:
|
||||
md.event_filters.append(RegexFilter(
|
||||
regex=command_name
|
||||
))
|
||||
else:
|
||||
md.event_filters.append(CommandFilter(
|
||||
command_name=command_name,
|
||||
handler_md=md
|
||||
))
|
||||
star_handlers_registry.append(md)
|
||||
|
||||
def register_task(self, task: Awaitable, desc: str):
|
||||
'''
|
||||
注册一个异步任务。
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
|
||||
import re
|
||||
import inspect
|
||||
from typing import List, Any, Type, Dict
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
class CommandFilter(HandlerFilter):
|
||||
'''标准指令过滤器'''
|
||||
def __init__(self, command_name: str, handler_md: StarHandlerMetadata = None):
|
||||
def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None, parent_command_names: List[str] = [""]):
|
||||
self.command_name = command_name
|
||||
self.alias = alias if alias else set()
|
||||
self.parent_command_names = parent_command_names
|
||||
if handler_md:
|
||||
self.init_handler_md(handler_md)
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
|
||||
def print_types(self):
|
||||
result = ""
|
||||
@@ -22,6 +26,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
result += f"{k}({v.__name__}),"
|
||||
else:
|
||||
result += f"{k}({type(v).__name__})={v},"
|
||||
result = result.rstrip(",")
|
||||
return result
|
||||
|
||||
def init_handler_md(self, handle_md: StarHandlerMetadata):
|
||||
@@ -42,23 +47,79 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
def get_handler_md(self) -> StarHandlerMetadata:
|
||||
return self.handler_md
|
||||
|
||||
def add_custom_filter(self, custom_filter: CustomFilter):
|
||||
self.custom_filter_list.append(custom_filter)
|
||||
|
||||
def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
for custom_filter in self.custom_filter_list:
|
||||
if not custom_filter.filter(event, cfg):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
|
||||
'''将参数列表 params 根据 param_type 转换为参数字典。
|
||||
'''
|
||||
result = {}
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
if i >= len(params):
|
||||
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
|
||||
# 是类型
|
||||
raise ValueError(f"必要参数缺失。该指令完整参数: {self.print_types()}")
|
||||
else:
|
||||
# 是默认值
|
||||
result[param_name] = param_type_or_default_val
|
||||
else:
|
||||
# 尝试强制转换
|
||||
try:
|
||||
if param_type_or_default_val is None:
|
||||
if params[i].isdigit():
|
||||
result[param_name] = int(params[i])
|
||||
else:
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, int):
|
||||
result[param_name] = int(params[i])
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
raise ValueError(f"参数 {param_name} 类型错误。完整参数: {self.print_types()}")
|
||||
return result
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_wake_up():
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
|
||||
message_str = event.get_message_str().strip()
|
||||
# 分割为列表(每个参数之间可能会有多个空格)
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0]:
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False
|
||||
# params_str = message_str[len(self.command_name):].strip()
|
||||
ls = ls[1:]
|
||||
|
||||
# 检查是否以指令开头
|
||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||
candidates = [self.command_name] + list(self.alias)
|
||||
ok = False
|
||||
for candidate in candidates:
|
||||
for parent_command_name in self.parent_command_names:
|
||||
if parent_command_name:
|
||||
_full = f"{parent_command_name} {candidate}"
|
||||
else:
|
||||
_full = candidate
|
||||
if message_str.startswith(f"{_full} ") or message_str == _full:
|
||||
message_str = message_str[len(_full):].strip()
|
||||
ok = True
|
||||
break
|
||||
if not ok:
|
||||
return False
|
||||
|
||||
# 分割为列表
|
||||
ls = message_str.split(" ")
|
||||
# 去除空字符串
|
||||
ls = [param for param in ls if param]
|
||||
params = {}
|
||||
try:
|
||||
params = self.validate_and_convert_params(ls, self.handler_params)
|
||||
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -6,65 +6,97 @@ from . import HandlerFilter
|
||||
from .command import CommandFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
# 指令组受到 wake_prefix 的制约。
|
||||
class CommandGroupFilter(HandlerFilter):
|
||||
def __init__(self, group_name: str):
|
||||
def __init__(self, group_name: str, alias: set = None, parent_group: CommandGroupFilter = None):
|
||||
self.group_name = group_name
|
||||
self.alias = alias if alias else set()
|
||||
self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = []
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]):
|
||||
self.sub_command_filters.append(sub_command_filter)
|
||||
|
||||
def add_custom_filter(self, custom_filter: CustomFilter):
|
||||
self.custom_filter_list.append(custom_filter)
|
||||
|
||||
def get_complete_command_names(self) -> List[str]:
|
||||
'''遍历父节点获取完整的指令名。
|
||||
|
||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。'''
|
||||
parent_cmd_names = self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||
|
||||
if not parent_cmd_names:
|
||||
# 根节点
|
||||
return [self.group_name] + list(self.alias)
|
||||
|
||||
result = []
|
||||
candidates = [self.group_name] + list(self.alias)
|
||||
for parent_cmd_name in parent_cmd_names:
|
||||
for candidate in candidates:
|
||||
result.append(parent_cmd_name + " " + candidate)
|
||||
return result
|
||||
|
||||
|
||||
# 以树的形式打印出来
|
||||
def print_cmd_tree(self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "") -> str:
|
||||
def print_cmd_tree(self,
|
||||
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
||||
prefix: str = "",
|
||||
event: AstrMessageEvent = None,
|
||||
cfg: AstrBotConfig = None,
|
||||
) -> str:
|
||||
result = ""
|
||||
for sub_filter in sub_command_filters:
|
||||
if isinstance(sub_filter, CommandFilter):
|
||||
cmd_th = sub_filter.print_types()
|
||||
result += f"{prefix}├── {sub_filter.command_name}"
|
||||
if cmd_th:
|
||||
result += f" ({cmd_th})"
|
||||
else:
|
||||
result += " (无参数指令)"
|
||||
custom_filter_pass = True
|
||||
if event and cfg:
|
||||
custom_filter_pass = sub_filter.custom_filter_ok(event, cfg)
|
||||
if custom_filter_pass:
|
||||
cmd_th = sub_filter.print_types()
|
||||
result += f"{prefix}├── {sub_filter.command_name}"
|
||||
if cmd_th:
|
||||
result += f" ({cmd_th})"
|
||||
else:
|
||||
result += " (无参数指令)"
|
||||
|
||||
result += "\n"
|
||||
if sub_filter.handler_md and sub_filter.handler_md.desc:
|
||||
result += f": {sub_filter.handler_md.desc}"
|
||||
|
||||
result += "\n"
|
||||
elif isinstance(sub_filter, CommandGroupFilter):
|
||||
result += f"{prefix}├── {sub_filter.group_name}"
|
||||
result += "\n"
|
||||
result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ")
|
||||
custom_filter_pass = True
|
||||
if event and cfg:
|
||||
custom_filter_pass = sub_filter.custom_filter_ok(event, cfg)
|
||||
if custom_filter_pass:
|
||||
result += f"{prefix}├── {sub_filter.group_name}"
|
||||
result += "\n"
|
||||
result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ", event=event, cfg=cfg)
|
||||
|
||||
return result
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
|
||||
if not event.is_wake_up():
|
||||
return False, None
|
||||
def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
for custom_filter in self.custom_filter_list:
|
||||
if not custom_filter.filter(event, cfg):
|
||||
return False
|
||||
return True
|
||||
|
||||
message_str = event.get_message_str().strip()
|
||||
ls = re.split(r"\s+", message_str)
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
|
||||
if ls[0] != self.group_name:
|
||||
return False, None
|
||||
# 改写 message_str
|
||||
ls = ls[1:]
|
||||
event.message_str = " ".join(ls)
|
||||
event.message_str = event.message_str.strip()
|
||||
# 判断当前指令组的自定义过滤器
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False
|
||||
|
||||
if event.message_str == "":
|
||||
# 当前还是指令组
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
|
||||
complete_command_names = self.get_complete_command_names()
|
||||
if event.message_str.strip() in complete_command_names:
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
|
||||
|
||||
child_command_handler_md = None
|
||||
for sub_filter in self.sub_command_filters:
|
||||
if isinstance(sub_filter, CommandFilter):
|
||||
if sub_filter.filter(event, cfg):
|
||||
child_command_handler_md = sub_filter.get_handler_md()
|
||||
return True, child_command_handler_md
|
||||
elif isinstance(sub_filter, CommandGroupFilter):
|
||||
ok, handler = sub_filter.filter(event, cfg)
|
||||
if ok:
|
||||
child_command_handler_md = handler
|
||||
return True, child_command_handler_md
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
|
||||
raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree)
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
# return event.message_str.startswith(tuple(complete_command_names))
|
||||
return False
|
||||
53
astrbot/core/star/filter/custom_filter.py
Normal file
53
astrbot/core/star/filter/custom_filter.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from abc import abstractmethod, ABCMeta
|
||||
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
class CustomFilterMeta(ABCMeta):
|
||||
def __and__(cls, other):
|
||||
if not issubclass(other, CustomFilter):
|
||||
raise TypeError("Operands must be subclasses of CustomFilter.")
|
||||
return CustomFilterAnd(cls(), other())
|
||||
|
||||
def __or__(cls, other):
|
||||
if not issubclass(other, CustomFilter):
|
||||
raise TypeError("Operands must be subclasses of CustomFilter.")
|
||||
return CustomFilterOr(cls(), other())
|
||||
|
||||
class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta):
|
||||
def __init__(self, raise_error: bool = True, **kwargs):
|
||||
self.raise_error = raise_error
|
||||
|
||||
@abstractmethod
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
''' 一个用于重写的自定义Filter '''
|
||||
raise NotImplementedError
|
||||
|
||||
def __or__(self, other):
|
||||
return CustomFilterOr(self, other)
|
||||
|
||||
def __and__(self, other):
|
||||
return CustomFilterAnd(self, other)
|
||||
|
||||
class CustomFilterOr(CustomFilter):
|
||||
def __init__(self, filter1: CustomFilter, filter2: CustomFilter):
|
||||
super().__init__()
|
||||
if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)):
|
||||
raise ValueError("CustomFilter lass can only operate with other CustomFilter.")
|
||||
self.filter1 = filter1
|
||||
self.filter2 = filter2
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
return self.filter1.filter(event, cfg) or self.filter2.filter(event, cfg)
|
||||
|
||||
class CustomFilterAnd(CustomFilter):
|
||||
def __init__(self, filter1: CustomFilter, filter2: CustomFilter):
|
||||
super().__init__()
|
||||
if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)):
|
||||
raise ValueError("CustomFilter lass can only operate with other CustomFilter.")
|
||||
self.filter1 = filter1
|
||||
self.filter2 = filter2
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
return self.filter1.filter(event, cfg) and self.filter2.filter(event, cfg)
|
||||
@@ -6,8 +6,8 @@ from astrbot.core.config import AstrBotConfig
|
||||
class PermissionType(enum.Flag):
|
||||
'''权限类型。当选择 MEMBER,ADMIN 也可以通过。
|
||||
'''
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
ADMIN = enum.auto()
|
||||
MEMBER = enum.auto()
|
||||
|
||||
class PermissionTypeFilter(HandlerFilter):
|
||||
def __init__(self, permission_type: PermissionType, raise_error: bool = True):
|
||||
@@ -19,7 +19,8 @@ class PermissionTypeFilter(HandlerFilter):
|
||||
'''
|
||||
if self.permission_type == PermissionType.ADMIN:
|
||||
if not event.is_admin():
|
||||
event.stop_event()
|
||||
raise ValueError("您没有权限执行此操作。")
|
||||
# event.stop_event()
|
||||
# raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -8,12 +8,14 @@ class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
VCHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
|
||||
GEWECHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT
|
||||
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"vchat": PlatformAdapterType.VCHAT
|
||||
"vchat": PlatformAdapterType.VCHAT,
|
||||
"gewechat": PlatformAdapterType.GEWECHAT
|
||||
}
|
||||
|
||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.config import AstrBotConfig
|
||||
class RegexFilter(HandlerFilter):
|
||||
'''正则表达式过滤器'''
|
||||
def __init__(self, regex: str):
|
||||
self.regex_str = regex
|
||||
self.regex = re.compile(regex)
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
|
||||
@@ -5,7 +5,13 @@ from .star_handler import (
|
||||
register_event_message_type,
|
||||
register_platform_adapter_type,
|
||||
register_regex,
|
||||
register_permission_type
|
||||
register_permission_type,
|
||||
register_custom_filter,
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -15,5 +21,11 @@ __all__ = [
|
||||
'register_event_message_type',
|
||||
'register_platform_adapter_type',
|
||||
'register_regex',
|
||||
'register_permission_type'
|
||||
'register_permission_type',
|
||||
'register_custom_filter',
|
||||
'register_on_llm_request',
|
||||
'register_on_llm_response',
|
||||
'register_llm_tool',
|
||||
'register_on_decorating_result',
|
||||
'register_after_message_sent'
|
||||
]
|
||||
@@ -1,81 +1,156 @@
|
||||
from __future__ import annotations
|
||||
import docstring_parser
|
||||
|
||||
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
|
||||
from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from ..filter.command import CommandFilter
|
||||
from ..filter.command_group import CommandGroupFilter
|
||||
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
||||
from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
def get_handler_full_name(awatable: Awaitable) -> str:
|
||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||
'''获取 Handler 的全名'''
|
||||
return f"{awatable.__module__}_{awatable.__name__}"
|
||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||
|
||||
def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMetadata:
|
||||
def get_handler_or_create(
|
||||
handler: Awaitable,
|
||||
event_type: EventType,
|
||||
dont_add = False,
|
||||
**kwargs
|
||||
) -> StarHandlerMetadata:
|
||||
'''获取 Handler 或者创建一个新的 Handler'''
|
||||
handler_full_name = get_handler_full_name(handler)
|
||||
if handler_full_name in star_handlers_map:
|
||||
return star_handlers_map[handler_full_name]
|
||||
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
|
||||
if md:
|
||||
return md
|
||||
else:
|
||||
md = StarHandlerMetadata(
|
||||
event_type=event_type,
|
||||
handler_full_name=handler_full_name,
|
||||
handler_name=handler.__name__,
|
||||
handler_module_str=handler.__module__,
|
||||
handler_module_path=handler.__module__,
|
||||
handler=handler,
|
||||
event_filters=[]
|
||||
)
|
||||
|
||||
# 插件handler的附加额外信息
|
||||
if handler.__doc__:
|
||||
md.desc = handler.__doc__.strip()
|
||||
if 'desc' in kwargs:
|
||||
md.desc = kwargs['desc']
|
||||
del kwargs['desc']
|
||||
md.extras_configs = kwargs
|
||||
|
||||
if not dont_add:
|
||||
star_handlers_registry.append(md)
|
||||
star_handlers_map[handler_full_name] = md
|
||||
return md
|
||||
|
||||
def register_command(command_name: str = None, *args):
|
||||
'''注册一个 Command'''
|
||||
|
||||
def register_command(command_name: str = None, sub_command: str = None, alias: set = None, **kwargs):
|
||||
'''注册一个 Command.
|
||||
'''
|
||||
new_command = None
|
||||
add_to_event_filters = False
|
||||
if isinstance(command_name, RegisteringCommandable):
|
||||
# 子指令
|
||||
new_command = CommandFilter(args[0], None)
|
||||
parent_command_names = command_name.parent_group.get_complete_command_names()
|
||||
logger.debug(f"parent_command_names: {parent_command_names}")
|
||||
new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names)
|
||||
command_name.parent_group.add_sub_command_filter(new_command)
|
||||
else:
|
||||
# 裸指令
|
||||
new_command = CommandFilter(command_name, None)
|
||||
new_command = CommandFilter(command_name, alias, None)
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable)
|
||||
if not add_to_event_filters:
|
||||
kwargs['sub_command'] = True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
new_command.init_handler_md(handler_md)
|
||||
if add_to_event_filters:
|
||||
# 裸指令
|
||||
handler_md.event_filters.append(new_command)
|
||||
|
||||
handler_md.event_filters.append(new_command)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_command_group(command_group_name: str = None, *args):
|
||||
'''注册一个 CommandGroup'''
|
||||
|
||||
new_group = None
|
||||
def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
'''注册一个自定义的 CustomFilter
|
||||
|
||||
Args:
|
||||
custom_type_filter: 在裸指令时为CustomFilter对象
|
||||
在指令组时为父指令的RegisteringCommandable对象,即self或者command_group的返回
|
||||
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
|
||||
'''
|
||||
add_to_event_filters = False
|
||||
raise_error = True
|
||||
|
||||
# 判断是否是指令组,指令组则添加到指令组的CommandGroupFilter对象中在waking_check的时候一起判断
|
||||
if isinstance(custom_type_filter, RegisteringCommandable):
|
||||
# 子指令, 此时函数为RegisteringCommandable对象的方法,首位参数为RegisteringCommandable对象的self。
|
||||
parent_register_commandable = custom_type_filter
|
||||
custom_filter = args[0]
|
||||
if len(args) > 1:
|
||||
raise_error = args[1]
|
||||
else:
|
||||
# 裸指令
|
||||
add_to_event_filters = True
|
||||
custom_filter = custom_type_filter
|
||||
if args:
|
||||
raise_error = args[0]
|
||||
|
||||
if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)):
|
||||
custom_filter = custom_filter(raise_error)
|
||||
|
||||
def decorator(awaitable):
|
||||
# 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。
|
||||
if not add_to_event_filters and isinstance(awaitable, RegisteringCommandable) or \
|
||||
(add_to_event_filters and isinstance(awaitable, RegisteringCommandable)):
|
||||
# 指令组 与 根指令组,添加到本层的grouphandle中一起判断
|
||||
awaitable.parent_group.add_custom_filter(custom_filter)
|
||||
else:
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
|
||||
if not add_to_event_filters and not isinstance(awaitable, RegisteringCommandable):
|
||||
# 底层子指令
|
||||
handle_full_name = get_handler_full_name(awaitable)
|
||||
for sub_handle in parent_register_commandable.parent_group.sub_command_filters:
|
||||
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
||||
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
||||
sub_handle_md = sub_handle.get_handler_md()
|
||||
if sub_handle_md and sub_handle_md.handler_full_name == handle_full_name:
|
||||
sub_handle.add_custom_filter(custom_filter)
|
||||
|
||||
else:
|
||||
# 裸指令
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(custom_filter)
|
||||
|
||||
return awaitable
|
||||
return decorator
|
||||
|
||||
def register_command_group(
|
||||
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
||||
):
|
||||
'''注册一个 CommandGroup
|
||||
'''
|
||||
new_group = None
|
||||
if isinstance(command_group_name, RegisteringCommandable):
|
||||
# 子指令组
|
||||
new_group = CommandGroupFilter(args[0])
|
||||
new_group = CommandGroupFilter(sub_command, alias, parent_group=command_group_name.parent_group)
|
||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||
else:
|
||||
# 根指令组
|
||||
new_group = CommandGroupFilter(command_group_name)
|
||||
add_to_event_filters = True
|
||||
new_group = CommandGroupFilter(command_group_name, alias)
|
||||
|
||||
def decorator(obj):
|
||||
if add_to_event_filters:
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj)
|
||||
handler_md.event_filters.append(new_group)
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
|
||||
@@ -83,36 +158,37 @@ def register_command_group(command_group_name: str = None, *args):
|
||||
|
||||
class RegisteringCommandable():
|
||||
'''用于指令组级联注册'''
|
||||
group = register_command_group
|
||||
command = register_command
|
||||
group: CommandGroupFilter = register_command_group
|
||||
command: CommandFilter = register_command
|
||||
custom_filter = register_custom_filter
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
|
||||
def register_event_message_type(event_message_type: EventMessageType):
|
||||
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
|
||||
'''注册一个 EventMessageType'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
|
||||
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs):
|
||||
'''注册一个 PlatformAdapterType'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_regex(regex: str):
|
||||
def register_regex(regex: str, **kwargs):
|
||||
'''注册一个 Regex'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(RegexFilter(regex))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -123,9 +199,121 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
|
||||
permission_type: PermissionType
|
||||
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
|
||||
'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_request(**kwargs):
|
||||
'''当有 LLM 请求时的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
|
||||
@on_llm_request()
|
||||
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
|
||||
request.system_prompt += "你是一个猫娘..."
|
||||
```
|
||||
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_response(**kwargs):
|
||||
'''当有 LLM 请求后的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.provider import LLMResponse
|
||||
|
||||
@on_llm_response()
|
||||
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
|
||||
...
|
||||
```
|
||||
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
|
||||
```
|
||||
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
|
||||
async def get_weather(event: AstrMessageEvent, location: str):
|
||||
\'\'\'获取天气信息。
|
||||
|
||||
Args:
|
||||
location(string): 地点
|
||||
\'\'\'
|
||||
# 处理逻辑
|
||||
```
|
||||
|
||||
可接受的参数类型有:string, number, object, array, boolean。
|
||||
|
||||
返回值:
|
||||
- 返回 str:结果会被加入下一次 LLM 请求的 prompt 中,用于让 LLM 总结工具返回的结果
|
||||
- 返回 None:结果不会被加入下一次 LLM 请求的 prompt 中。
|
||||
|
||||
可以使用 yield 发送消息、终止事件。
|
||||
|
||||
发送消息:请参考文档。
|
||||
|
||||
终止事件:
|
||||
```
|
||||
event.stop_event()
|
||||
yield
|
||||
```
|
||||
'''
|
||||
|
||||
name_ = name
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
docstring = docstring_parser.parse(awaitable.__doc__)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
|
||||
args.append({
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description
|
||||
})
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_decorating_result(**kwargs):
|
||||
'''在发送消息前的事件'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_after_message_sent(**kwargs):
|
||||
'''在消息发送后的事件'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
star_map: Dict[str, StarMetadata] = {}
|
||||
@@ -11,7 +12,7 @@ star_map: Dict[str, StarMetadata] = {}
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
'''
|
||||
Star 的元数据。
|
||||
插件的元数据。
|
||||
'''
|
||||
name: str
|
||||
author: str # 插件作者
|
||||
@@ -20,18 +21,27 @@ class StarMetadata:
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
star_cls_type: type = None
|
||||
'''Star 的类对象的类型'''
|
||||
'''插件的类对象的类型'''
|
||||
module_path: str = None
|
||||
'''Star 的模块路径'''
|
||||
'''插件的模块路径'''
|
||||
|
||||
star_cls: object = None
|
||||
'''Star 的类对象'''
|
||||
'''插件的类对象'''
|
||||
module: ModuleType = None
|
||||
'''Star 的模块对象'''
|
||||
'''插件的模块对象'''
|
||||
root_dir_name: str = None
|
||||
'''Star 的根目录名'''
|
||||
'''插件的目录名称'''
|
||||
reserved: bool = False
|
||||
'''是否是 AstrBot 的保留 Star'''
|
||||
'''是否是 AstrBot 的保留插件'''
|
||||
|
||||
activated: bool = True
|
||||
'''是否被激活'''
|
||||
|
||||
config: AstrBotConfig = None
|
||||
'''插件配置'''
|
||||
|
||||
star_handler_full_names: List[str] = field(default_factory=list)
|
||||
'''注册的 Handler 的全名列表'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
@@ -1,31 +1,117 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, List, Dict
|
||||
import enum
|
||||
import heapq
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
|
||||
star_handlers_registry: List[StarHandlerMetadata] = []
|
||||
T = TypeVar('T', bound='StarHandlerMetadata')
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
'''用于存储所有的 Star Handler'''
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
'''用于快速查找。key 是 handler_full_name'''
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
'''用于快速查找。key 是 handler_full_name'''
|
||||
_handlers = []
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
'''添加一个 Handler'''
|
||||
if 'priority' not in handler.extras_configs:
|
||||
handler.extras_configs['priority'] = 0
|
||||
|
||||
heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler))
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
|
||||
def _print_handlers(self):
|
||||
'''打印所有的 Handler'''
|
||||
for _, handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]:
|
||||
'''通过事件类型获取 Handler'''
|
||||
handlers = [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
if handler.event_type == event_type and
|
||||
(not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated))
|
||||
]
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
'''通过 Handler 的全名获取 Handler'''
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
|
||||
'''通过模块名获取 Handler'''
|
||||
return [handler for _, handler in self._handlers if handler.handler_module_path == module_name]
|
||||
|
||||
def clear(self):
|
||||
'''清空所有的 Handler'''
|
||||
self.star_handlers_map.clear()
|
||||
self._handlers.clear()
|
||||
|
||||
def remove(self, handler: StarHandlerMetadata):
|
||||
'''删除一个 Handler'''
|
||||
# self._handlers.remove(handler)
|
||||
for i, h in enumerate(self._handlers):
|
||||
if h[1] == handler:
|
||||
self._handlers.pop(i)
|
||||
break
|
||||
try:
|
||||
del self.star_handlers_map[handler.handler_full_name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
'''使 StarHandlerRegistry 支持迭代'''
|
||||
return (handler for _, handler in self._handlers)
|
||||
|
||||
def __len__(self):
|
||||
'''返回 Handler 的数量'''
|
||||
return len(self._handlers)
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
class EventType(enum.Enum):
|
||||
'''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等
|
||||
|
||||
用于对 Handler 的职能分组。
|
||||
'''
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata():
|
||||
'''描述一个 Star 所注册的某一个 Handler。'''
|
||||
|
||||
event_type: EventType
|
||||
'''Handler 的事件类型'''
|
||||
|
||||
handler_full_name: str
|
||||
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
|
||||
|
||||
handler_name: str
|
||||
'''Handler 的名字,也就是方法名'''
|
||||
|
||||
handler_module_str: str
|
||||
handler_module_path: str
|
||||
'''Handler 所在的模块路径。'''
|
||||
|
||||
handler: Awaitable
|
||||
'''Handler 的函数对象,应当是一个异步函数'''
|
||||
|
||||
event_filters: List[HandlerFilter]
|
||||
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
|
||||
'''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件'''
|
||||
|
||||
desc: str = ""
|
||||
'''Handler 的描述信息'''
|
||||
|
||||
extras_configs: dict = field(default_factory=dict)
|
||||
'''插件注册的一些其他的信息, 如 priority 等'''
|
||||
|
||||
def __lt__(self, other: StarHandlerMetadata):
|
||||
'''定义小于运算符以支持优先队列'''
|
||||
return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0)
|
||||
@@ -1,20 +1,24 @@
|
||||
import inspect
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from pip import main as pip_main
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
|
||||
class PluginManager:
|
||||
def __init__(
|
||||
@@ -25,12 +29,22 @@ class PluginManager:
|
||||
self.updator = PluginUpdator(config['plugin_repo_mirror'])
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
|
||||
'''存储插件的路径。即 data/plugins'''
|
||||
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
|
||||
'''存储插件配置的路径。data/config'''
|
||||
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
|
||||
'''保留插件的路径。在 packages 目录下'''
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
'''插件配置 Schema 文件名'''
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
|
||||
classes = []
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
for (name, _) in clsmembers:
|
||||
@@ -89,21 +103,12 @@ class PluginManager:
|
||||
plugin_path = os.path.join(plugin_dir, p)
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
|
||||
pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
def _update_plugin_dept(self, path):
|
||||
'''更新插件的依赖'''
|
||||
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
|
||||
if self.config.pip_install_arg:
|
||||
args.extend(self.config.pip_install_arg)
|
||||
result_code = pip_main(args)
|
||||
if result_code != 0:
|
||||
raise Exception(str(result_code))
|
||||
|
||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
|
||||
'''v3.4.0 以前的方式载入插件元数据
|
||||
|
||||
@@ -123,7 +128,7 @@ class PluginManager:
|
||||
|
||||
if isinstance(metadata, dict):
|
||||
if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata:
|
||||
raise Exception("插件元数据信息不完整。")
|
||||
raise Exception("插件元数据信息不完整。name, desc, version, author 是必须的字段。")
|
||||
metadata = StarMetadata(
|
||||
name=metadata['name'],
|
||||
author=metadata['author'],
|
||||
@@ -134,28 +139,68 @@ class PluginManager:
|
||||
|
||||
return metadata
|
||||
|
||||
def reload(self):
|
||||
'''扫描并加载所有的 Star'''
|
||||
star_handlers_registry.clear()
|
||||
async def reload(self, specified_plugin_name=None):
|
||||
'''扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件'''
|
||||
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
if smd.name == specified_plugin_name:
|
||||
specified_module_path = smd.module_path
|
||||
break
|
||||
|
||||
# 终止插件
|
||||
if not specified_module_path:
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
smd.star_cls.__del__()
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
if smd:
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
try:
|
||||
del sys.modules[specified_module_path]
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {specified_module_path} 未载入")
|
||||
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
return False, "未找到任何插件模块"
|
||||
|
||||
fail_rec = ""
|
||||
|
||||
# 导入 Star 模块,并尝试实例化 Star 类
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
alter_cmd = sp.get("alter_cmd", {})
|
||||
|
||||
# 导入插件模块,并尝试实例化插件类
|
||||
for plugin_module in plugin_modules:
|
||||
try:
|
||||
module_str = plugin_module['module']
|
||||
# module_path = plugin_module['module_path']
|
||||
root_dir_name = plugin_module['pname']
|
||||
reserved = plugin_module.get('reserved', False)
|
||||
root_dir_name = plugin_module['pname'] # 插件的目录名
|
||||
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path += root_dir_name + "." + module_str
|
||||
|
||||
if specified_module_path and path != specified_module_path:
|
||||
continue
|
||||
|
||||
logger.info(f"正在载入插件 {root_dir_name} ...")
|
||||
|
||||
# 尝试导入模块
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path += root_dir_name + "." + module_str
|
||||
try:
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
@@ -167,27 +212,77 @@ class PluginManager:
|
||||
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
|
||||
continue
|
||||
|
||||
# 检查 _conf_schema.json
|
||||
plugin_config = None
|
||||
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
|
||||
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
|
||||
if os.path.exists(plugin_schema_path):
|
||||
# 加载插件配置
|
||||
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
|
||||
plugin_config = AstrBotConfig(
|
||||
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
|
||||
schema=json.loads(f.read())
|
||||
)
|
||||
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
star_metadata = star_map[path]
|
||||
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
|
||||
star_metadata.module = module
|
||||
star_metadata.root_dir_name = root_dir_name
|
||||
star_metadata.reserved = reserved
|
||||
metadata = star_map[path]
|
||||
|
||||
try:
|
||||
# yaml 文件的元数据优先
|
||||
metadata_yaml = self._load_plugin_metadata(plugin_path=plugin_dir_path)
|
||||
if metadata_yaml:
|
||||
metadata.name = metadata_yaml.name
|
||||
metadata.author = metadata_yaml.author
|
||||
metadata.desc = metadata_yaml.desc
|
||||
metadata.version = metadata_yaml.version
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
|
||||
except TypeError as _:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
else:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
|
||||
for handler in related_handlers:
|
||||
handler.handler = functools.partial(handler.handler, metadata.star_cls)
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler.__module__ == metadata.module_path:
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
|
||||
classes = self._get_classes(module)
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"插件 {root_dir_name} 实例化失败。")
|
||||
raise e
|
||||
|
||||
if plugin_config:
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
|
||||
except TypeError as _:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
else:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
|
||||
metadata = None
|
||||
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_dir_path, plugin_obj=obj)
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
@@ -195,11 +290,45 @@ class PluginManager:
|
||||
metadata.module_path = path
|
||||
star_map[path] = metadata
|
||||
star_registry.append(metadata)
|
||||
logger.debug(f"插件 {root_dir_name} 载入成功。")
|
||||
|
||||
# 禁用/启用插件
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
full_names = []
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(metadata.module_path):
|
||||
full_names.append(handler.handler_full_name)
|
||||
|
||||
# 检查并且植入自定义的权限过滤器(alter_cmd)
|
||||
if metadata.name in alter_cmd and handler.handler_name in alter_cmd[metadata.name]:
|
||||
cmd_type = alter_cmd[metadata.name][handler.handler_name].get("permission", "member")
|
||||
found_permission_filter = False
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, PermissionTypeFilter):
|
||||
if cmd_type == "admin":
|
||||
filter_.permission_type = PermissionType.ADMIN
|
||||
else:
|
||||
filter_.permission_type = PermissionType.MEMBER
|
||||
found_permission_filter = True
|
||||
break
|
||||
if not found_permission_filter:
|
||||
handler.event_filters.append(PermissionTypeFilter(PermissionType.ADMIN if cmd_type == "admin" else PermissionType.MEMBER))
|
||||
|
||||
logger.debug(f"插入权限过滤器 {cmd_type} 到 {metadata.name} 的 {handler.handler_name} 方法。")
|
||||
|
||||
metadata.star_handler_full_names = full_names
|
||||
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
|
||||
logger.error(f"----- 插件 {root_dir_name} 载入失败 -----")
|
||||
errors = traceback.format_exc()
|
||||
for line in errors.split('\n'):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("----------------------------------")
|
||||
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}。\n"
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
@@ -208,14 +337,16 @@ class PluginManager:
|
||||
if not fail_rec:
|
||||
return True, None
|
||||
else:
|
||||
self.failed_plugin_info = fail_rec
|
||||
return False, fail_rec
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
plugin_path = await self.updator.install(repo_url)
|
||||
self._check_plugin_dept_update()
|
||||
async def install_plugin(self, repo_url: str, proxy=""):
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
await self.reload()
|
||||
return plugin_path
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str):
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
@@ -224,22 +355,81 @@ class PluginManager:
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
del star_map[plugin.module_path]
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
del star_map[plugin_module_path]
|
||||
for i, p in enumerate(star_registry):
|
||||
if p.name == plugin_name:
|
||||
del star_registry[i]
|
||||
break
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path):
|
||||
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
|
||||
star_handlers_registry.remove(handler)
|
||||
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)]
|
||||
for k in keys_to_delete:
|
||||
v = star_handlers_registry.star_handlers_map[k]
|
||||
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
async def update_plugin(self, plugin_name: str, proxy = ""):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||
|
||||
await self.updator.update(plugin)
|
||||
await self.updator.update(plugin, proxy=proxy)
|
||||
await self.reload()
|
||||
|
||||
def install_plugin_from_file(self, zip_file_path: str):
|
||||
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
|
||||
async def turn_off_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = False
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if plugin.module_path in inactivated_plugins:
|
||||
inactivated_plugins.remove(plugin.module_path)
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
|
||||
# 启用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
func_tool.active = True
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = True
|
||||
|
||||
|
||||
async def install_plugin_from_file(self, zip_file_path: str):
|
||||
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
||||
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
|
||||
desti_dir = os.path.join(self.plugin_store_path, dir_name)
|
||||
self.updator.unzip_file(zip_file_path, desti_dir)
|
||||
|
||||
# remove the zip
|
||||
@@ -247,5 +437,4 @@ class PluginManager:
|
||||
os.remove(zip_file_path)
|
||||
except BaseException as e:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
|
||||
self._check_plugin_dept_update()
|
||||
await self.reload()
|
||||
|
||||
@@ -15,20 +15,24 @@ class PluginUpdator(RepoZipUpdator):
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
async def install(self, repo_url: str) -> str:
|
||||
async def install(self, repo_url: str, proxy="") -> str:
|
||||
repo_name = self.format_repo_name(repo_url)
|
||||
plugin_path = os.path.join(self.plugin_store_path, repo_name)
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
await self.download_from_repo_url(plugin_path, repo_url, proxy)
|
||||
self.unzip_file(plugin_path + ".zip", plugin_path)
|
||||
|
||||
return plugin_path
|
||||
|
||||
async def update(self, plugin: StarMetadata) -> str:
|
||||
async def update(self, plugin: StarMetadata, proxy="") -> str:
|
||||
repo_url = plugin.repo
|
||||
|
||||
if not repo_url:
|
||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
repo_url = f"{proxy}/{repo_url}"
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
@@ -53,7 +57,6 @@ class PluginUpdator(RepoZipUpdator):
|
||||
|
||||
files = os.listdir(os.path.join(target_dir, update_dir))
|
||||
for f in files:
|
||||
logger.info(f"移动更新文件/目录: {f}")
|
||||
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
|
||||
if os.path.exists(os.path.join(target_dir, f)):
|
||||
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
|
||||
@@ -63,7 +66,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
|
||||
|
||||
try:
|
||||
logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
|
||||
os.remove(zip_path)
|
||||
except BaseException:
|
||||
|
||||
@@ -11,7 +11,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
|
||||
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
try:
|
||||
|
||||
173
astrbot/core/utils/dify_api_client.py
Normal file
173
astrbot/core/utils/dify_api_client.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import json
|
||||
from astrbot.core import logger
|
||||
from aiohttp import ClientSession
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
|
||||
class DifyAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = ClientSession()
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
inputs: Dict,
|
||||
query: str,
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
conversation_id: str = "",
|
||||
files: List[Dict[str, Any]] = [],
|
||||
timeout: float = 60,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
url = f"{self.api_base}/chat-messages"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
logger.info(f"chat_messages payload: {payload}")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk.decode('utf-8')
|
||||
blocks = buffer.split('\n\n')
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith('data:'):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
inputs: Dict,
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
files: List[Dict[str, Any]] = [],
|
||||
timeout: float = 60,
|
||||
):
|
||||
url = f"{self.api_base}/workflows/run"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
logger.info(f"workflow_run payload: {payload}")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk.decode('utf-8')
|
||||
blocks = buffer.split('\n\n')
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith('data:'):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
file_path: str,
|
||||
user: str,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.api_base}/files/upload"
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": open(file_path, "rb"),
|
||||
}
|
||||
async with self.session.post(
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
|
||||
async def get_chat_convs(
|
||||
self,
|
||||
user: str,
|
||||
limit: int = 20
|
||||
):
|
||||
# conversations. GET
|
||||
url = f"{self.api_base}/conversations"
|
||||
payload = {
|
||||
"user": user,
|
||||
"limit": limit,
|
||||
}
|
||||
async with self.session.get(
|
||||
url, params=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json()
|
||||
|
||||
async def delete_chat_conv(
|
||||
self,
|
||||
user: str,
|
||||
conversation_id: str
|
||||
):
|
||||
# conversation. DELETE
|
||||
url = f"{self.api_base}/conversations/{conversation_id}"
|
||||
payload = {
|
||||
"user": user,
|
||||
}
|
||||
async with self.session.delete(
|
||||
url, json=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json()
|
||||
|
||||
async def rename(
|
||||
self,
|
||||
conversation_id: str,
|
||||
name: str,
|
||||
user: str,
|
||||
auto_generate: bool = False
|
||||
):
|
||||
# /conversations/:conversation_id/name
|
||||
url = f"{self.api_base}/conversations/{conversation_id}/name"
|
||||
payload = {
|
||||
"user": user,
|
||||
"name": name,
|
||||
"auto_generate": auto_generate,
|
||||
}
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json()
|
||||
|
||||
@@ -5,6 +5,10 @@ import socket
|
||||
import time
|
||||
import aiohttp
|
||||
import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
import psutil
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -40,21 +44,21 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
return False
|
||||
|
||||
|
||||
def save_temp_img(img: Image) -> str:
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
# 获得文件创建时间,清除超过1小时的
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600:
|
||||
if time.time() - ctime > 3600*12:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
print(f"清除临时文件失败: {e}")
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = int(time.time())
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
@@ -64,23 +68,33 @@ def save_temp_img(img: Image) -> str:
|
||||
f.write(img)
|
||||
return p
|
||||
|
||||
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
|
||||
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
|
||||
'''
|
||||
下载图片, 返回 path
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
else:
|
||||
async with session.get(url) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
except aiohttp.client_exceptions.ClientConnectorSSLError:
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers('DEFAULT')
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
@@ -90,21 +104,56 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def download_file(url: str, path: str):
|
||||
async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
'''
|
||||
从指定 url 下载文件到指定路径 path
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, timeout=1800) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
raise e
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers('DEFAULT')
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
|
||||
if show_progress:
|
||||
print()
|
||||
|
||||
|
||||
def file_to_base64(file_path: str) -> str:
|
||||
with open(file_path, "rb") as f:
|
||||
@@ -112,14 +161,34 @@ def file_to_base64(file_path: str) -> str:
|
||||
base64_str = base64.b64encode(data_bytes).decode()
|
||||
return "base64://" + base64_str
|
||||
|
||||
|
||||
def get_local_ip_addresses():
|
||||
ip = ''
|
||||
net_interfaces = psutil.net_if_addrs()
|
||||
network_ips = []
|
||||
|
||||
for interface, addrs in net_interfaces.items():
|
||||
for addr in addrs:
|
||||
if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET
|
||||
network_ips.append(addr.address)
|
||||
|
||||
return network_ips
|
||||
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
async def download_dashboard():
|
||||
'''下载管理面板文件'''
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect(('8.8.8.8', 80))
|
||||
ip = s.getsockname()[0]
|
||||
except BaseException:
|
||||
pass
|
||||
finally:
|
||||
s.close()
|
||||
return ip
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user