Compare commits
435 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16ec462abd | ||
|
|
ca55465d3c | ||
|
|
7098c98dde | ||
|
|
f56355da89 | ||
|
|
422160debd | ||
|
|
8062cf406a | ||
|
|
0e802232ec | ||
|
|
f650a9205d | ||
|
|
c85dbb2347 | ||
|
|
a6a79128c8 | ||
|
|
42839627e8 | ||
|
|
267e68a894 | ||
|
|
b32b444438 | ||
|
|
522d0f8313 | ||
|
|
5715e5de67 | ||
|
|
cc6b05e8b3 | ||
|
|
417747d5d0 | ||
|
|
a34f439226 | ||
|
|
b7ca014fd0 | ||
|
|
fa098d585a | ||
|
|
c35a14e3ec | ||
|
|
60651736a5 | ||
|
|
581f9b7bd3 | ||
|
|
124eb04807 | ||
|
|
1d561da7fb | ||
|
|
16e3cd0784 | ||
|
|
a6d91933dc | ||
|
|
445c40f758 | ||
|
|
725a841a3b | ||
|
|
f77c453843 | ||
|
|
ba6718d5bc | ||
|
|
cdb7a1b3fa | ||
|
|
a03c79b89d | ||
|
|
98800d3426 | ||
|
|
a616adaac4 | ||
|
|
ffb5605c99 | ||
|
|
621b556856 | ||
|
|
a3ffecbb2a | ||
|
|
ea64cebe2a | ||
|
|
e79487dd5f | ||
|
|
7fe1c1ec89 | ||
|
|
ab2bbff369 | ||
|
|
ec32825309 | ||
|
|
fd0c182087 | ||
|
|
49fcff1daf | ||
|
|
33b64ddf39 | ||
|
|
4c447aa648 | ||
|
|
ccbfc3d274 | ||
|
|
f83fe43bbb | ||
|
|
19022d67f8 | ||
|
|
58a815dd6b | ||
|
|
bc9fe82860 | ||
|
|
b3cd9bf2b9 | ||
|
|
c5c2b829ec | ||
|
|
9713f96401 | ||
|
|
11f35ebf96 | ||
|
|
7d403aa181 | ||
|
|
64af810a4a | ||
|
|
30821905af | ||
|
|
a9dbff756b | ||
|
|
a6aba10d3d | ||
|
|
9c276c37fe | ||
|
|
6ab6c0fd4c | ||
|
|
b6b0fe3fff | ||
|
|
0d5825bda9 | ||
|
|
cdfb64631a | ||
|
|
d161c281c8 | ||
|
|
8fed5bf2a1 | ||
|
|
98d2e9bd27 | ||
|
|
a03af55edd | ||
|
|
86e2fd9aee | ||
|
|
97bd0e5e58 | ||
|
|
ceaba21986 | ||
|
|
172a77d942 | ||
|
|
4f9d2d2a7d | ||
|
|
8c929f6e05 | ||
|
|
3319b71f5b | ||
|
|
46ec028a5b | ||
|
|
0ce0ef3e5c | ||
|
|
375b071cb2 | ||
|
|
29e1417ff2 | ||
|
|
75db2bd366 | ||
|
|
60ca1efbda | ||
|
|
2692e4978b | ||
|
|
91982eb002 | ||
|
|
bb1dec76fa | ||
|
|
f618b8fcdc | ||
|
|
9147cab75b | ||
|
|
5f07bcc8e6 | ||
|
|
705cf2ea1b | ||
|
|
42c4394484 | ||
|
|
221221a3c1 | ||
|
|
9564166297 | ||
|
|
f5cf3c3c8e | ||
|
|
18f919fb6b | ||
|
|
0924835253 | ||
|
|
20d2e5c578 | ||
|
|
907801605c | ||
|
|
93bc684e8c | ||
|
|
a76c98d57e | ||
|
|
d937a800d0 | ||
|
|
d16f3a227f | ||
|
|
80c9a3eeda | ||
|
|
e68173b451 | ||
|
|
40c27d87f5 | ||
|
|
3c13b5049d | ||
|
|
8288d5e51f | ||
|
|
6e1449900a | ||
|
|
4ffbb18ab4 | ||
|
|
b27271b7a3 | ||
|
|
ebb6665f64 | ||
|
|
e4e5731ffd | ||
|
|
2ab5810f13 | ||
|
|
af934c5d09 | ||
|
|
1e0cf7c112 | ||
|
|
46859c93c9 | ||
|
|
ea1f9cb3b2 | ||
|
|
1641549016 | ||
|
|
716a5dbb8a | ||
|
|
af98cb11c5 | ||
|
|
9a4c2cf341 | ||
|
|
2bc3bcd102 | ||
|
|
d6c663f79d | ||
|
|
9ed86e5f53 | ||
|
|
303e0bc037 | ||
|
|
2cc24019f9 | ||
|
|
83ce774d19 | ||
|
|
2b4ee13b5e | ||
|
|
3a964561f0 | ||
|
|
6959f86632 | ||
|
|
537d373e10 | ||
|
|
cceadf222c | ||
|
|
cf5a4af623 | ||
|
|
39aea11c22 | ||
|
|
c2f1227700 | ||
|
|
900f14d37c | ||
|
|
598249b1d6 | ||
|
|
7ed15bdf04 | ||
|
|
2fc0ec0f72 | ||
|
|
5e9c2a669b | ||
|
|
b310521884 | ||
|
|
288945bf7e | ||
|
|
4fc07cff36 | ||
|
|
b884fe0e86 | ||
|
|
855858c236 | ||
|
|
c11a2a5419 | ||
|
|
773a6572af | ||
|
|
88ad373c9b | ||
|
|
51666464b9 | ||
|
|
5af9cf2f52 | ||
|
|
12c4ae4b10 | ||
|
|
4e1bef414a | ||
|
|
e896c18644 | ||
|
|
c852685e74 | ||
|
|
1e99797df8 | ||
|
|
52a4c986a8 | ||
|
|
c501728204 | ||
|
|
6b067fa6a7 | ||
|
|
a1cd5c53a9 | ||
|
|
a46d487e03 | ||
|
|
3deb6d3ab3 | ||
|
|
af34cdd5d2 | ||
|
|
6e1393235a | ||
|
|
343e0b54b9 | ||
|
|
ecb70cb6f7 | ||
|
|
ca50618af6 | ||
|
|
29c07ba83e | ||
|
|
45fbb83a9f | ||
|
|
ae7ba2df25 | ||
|
|
c3ef57cc32 | ||
|
|
7bb4ca5a14 | ||
|
|
063783d81d | ||
|
|
42116c9b65 | ||
|
|
a36e11973d | ||
|
|
5125568ea2 | ||
|
|
0fa164e50d | ||
|
|
cf814e81ee | ||
|
|
43a45f18ce | ||
|
|
ad51381063 | ||
|
|
0b0e4ce904 | ||
|
|
6a3e04d688 | ||
|
|
4107a17370 | ||
|
|
06b4d8f169 | ||
|
|
1c0c820746 | ||
|
|
d061403a28 | ||
|
|
5c092321a6 | ||
|
|
bdd3f61c1f | ||
|
|
8023557d6e | ||
|
|
074b0ced7a | ||
|
|
3864b1ac9b | ||
|
|
6e9b43457d | ||
|
|
ca1aec8920 | ||
|
|
acac580862 | ||
|
|
673e1b2980 | ||
|
|
f62157be72 | ||
|
|
f894ecf3b6 | ||
|
|
66dd4e28ad | ||
|
|
939dc1b0fb | ||
|
|
56bf5d38a1 | ||
|
|
d09b70b295 | ||
|
|
205180387a | ||
|
|
39c8cfeda5 | ||
|
|
f38a329be5 | ||
|
|
a0cd069539 | ||
|
|
bf306a2f01 | ||
|
|
c31f93a8d1 | ||
|
|
4730ab6309 | ||
|
|
1ae78ca98c | ||
|
|
d2379da478 | ||
|
|
0f64981b20 | ||
|
|
0002e49bb5 | ||
|
|
db13a60274 | ||
|
|
db0f11a359 | ||
|
|
ac7f43520b | ||
|
|
f67b9f5f6e | ||
|
|
c75156c4ce | ||
|
|
10270b5595 | ||
|
|
f7458572ed | ||
|
|
d57b7222b2 | ||
|
|
62e70a673a | ||
|
|
5e9eba6478 | ||
|
|
cb02dfe1a4 | ||
|
|
b50739e1af | ||
|
|
8da1b0212d | ||
|
|
ca1f2acb33 | ||
|
|
c15f966669 | ||
|
|
7705b8781a | ||
|
|
b2502746f0 | ||
|
|
ab68094386 | ||
|
|
bbec701223 | ||
|
|
b29d14e600 | ||
|
|
86e51c5cd1 | ||
|
|
cb8267be3f | ||
|
|
eaed43915c | ||
|
|
bd91fd2c38 | ||
|
|
1203b214cd | ||
|
|
c3fec15f11 | ||
|
|
0545653494 | ||
|
|
db2989bdb4 | ||
|
|
587bd00a19 | ||
|
|
960ff438e8 | ||
|
|
98e7ea85d3 | ||
|
|
2549e44710 | ||
|
|
4d32b563ca | ||
|
|
3a4b732977 | ||
|
|
500909a28e | ||
|
|
07753eb25b | ||
|
|
c6eaf3d010 | ||
|
|
6723fe8271 | ||
|
|
3348b70435 | ||
|
|
35a8527c16 | ||
|
|
7afc475290 | ||
|
|
789bceaa3a | ||
|
|
abbc043969 | ||
|
|
654e5762f1 | ||
|
|
507c3e3629 | ||
|
|
991dfeb2f2 | ||
|
|
26482fc2d3 | ||
|
|
e0ce6d9688 | ||
|
|
946595216a | ||
|
|
864b6bc56d | ||
|
|
6ea5b7581f | ||
|
|
f70b8f0c10 | ||
|
|
1593bcb537 | ||
|
|
bf7fc02c8d | ||
|
|
143702b92b | ||
|
|
c5ccc1a084 | ||
|
|
2ecb52a9b2 | ||
|
|
6439917cbe | ||
|
|
d21c18f657 | ||
|
|
25ef0039e4 | ||
|
|
e6981290bc | ||
|
|
75c3d8abbd | ||
|
|
d88683f498 | ||
|
|
40b9aa3a4c | ||
|
|
b6d1515d58 | ||
|
|
e01d4264e3 | ||
|
|
2117b65487 | ||
|
|
a7823b352f | ||
|
|
c543b62a08 | ||
|
|
3923b87f08 | ||
|
|
b7ecdadb83 | ||
|
|
5ff121e1ed | ||
|
|
f486e5448f | ||
|
|
c5aae98558 | ||
|
|
6d8a3b9897 | ||
|
|
6d98780e19 | ||
|
|
3ad2c46f3f | ||
|
|
a730cee7fd | ||
|
|
77c823c100 | ||
|
|
124f21c67a | ||
|
|
e46cf20dd3 | ||
|
|
4bef5e8313 | ||
|
|
22e93b0af4 | ||
|
|
5aeca9662b | ||
|
|
b996cf1f05 | ||
|
|
878a106877 | ||
|
|
45d36f86fd | ||
|
|
b108ae403a | ||
|
|
887ed66768 | ||
|
|
dac840a887 | ||
|
|
238de4ba8c | ||
|
|
9a7bdade43 | ||
|
|
aa84556204 | ||
|
|
6b68069fcd | ||
|
|
42c7034fb2 | ||
|
|
060c7e0145 | ||
|
|
b5b085dfb1 | ||
|
|
fc06ce9d7f | ||
|
|
d8d81b05a7 | ||
|
|
a60f42b1f2 | ||
|
|
6e18be88d0 | ||
|
|
b45e439c48 | ||
|
|
b87061c18c | ||
|
|
f78aca7752 | ||
|
|
3ccca2aa10 | ||
|
|
6d7c40eb76 | ||
|
|
da4cd7fb65 | ||
|
|
c97cda6b84 | ||
|
|
7a7fd4167a | ||
|
|
dffc1a43d5 | ||
|
|
36897fea1e | ||
|
|
c7b34735f0 | ||
|
|
5b07176c88 | ||
|
|
474b40d660 | ||
|
|
a62901b948 | ||
|
|
25d8746327 | ||
|
|
aff1698223 | ||
|
|
7f8941745f | ||
|
|
b858401098 | ||
|
|
d5a158b80f | ||
|
|
f315f284aa | ||
|
|
c367f5009d | ||
|
|
6db1e63bda | ||
|
|
e22ab2ede6 | ||
|
|
b7d7e0b682 | ||
|
|
96bba15f2f | ||
|
|
fcf965a595 | ||
|
|
e1a20d3c22 | ||
|
|
2abd7d8c5d | ||
|
|
5b8f73cdd7 | ||
|
|
7fd765421f | ||
|
|
d9d94af022 | ||
|
|
790b924e57 | ||
|
|
4a62f877df | ||
|
|
ac47c57bb7 | ||
|
|
3ace4199a1 | ||
|
|
e6bd7524c1 | ||
|
|
699c86e8c1 | ||
|
|
f40fa0ecea | ||
|
|
626f94686b | ||
|
|
752d13b1b1 | ||
|
|
54c0dc1b2b | ||
|
|
c5bc709898 | ||
|
|
ccdbb01513 | ||
|
|
5206d750ac | ||
|
|
a800e3df67 | ||
|
|
ccb1f87a20 | ||
|
|
c111da4681 | ||
|
|
9cc4e97a53 | ||
|
|
dca1c0b0f3 | ||
|
|
f06be6ed21 | ||
|
|
3c8ec2f42e | ||
|
|
7e193f7f52 | ||
|
|
7069b02929 | ||
|
|
66995db927 | ||
|
|
c36054ca1b | ||
|
|
3e07fbf3dc | ||
|
|
bf3fbe3e96 | ||
|
|
0a93d22bc8 | ||
|
|
f5b3d94d16 | ||
|
|
4d1a6994aa | ||
|
|
05c686782c | ||
|
|
85609ea742 | ||
|
|
20dabc0615 | ||
|
|
356dd9bc2b | ||
|
|
cd5d7534c4 | ||
|
|
b4f12fc933 | ||
|
|
cbea387ce0 | ||
|
|
345b155374 | ||
|
|
29d216950e | ||
|
|
321b04772c | ||
|
|
5b924aee98 | ||
|
|
46d44e3405 | ||
|
|
4d5332fe25 | ||
|
|
18bd4c54f4 | ||
|
|
31c7768ca0 | ||
|
|
6ec643e9d1 | ||
|
|
2b39f6f61c | ||
|
|
bf3ca13961 | ||
|
|
82026370ec | ||
|
|
6d49bf5346 | ||
|
|
67431d87fb | ||
|
|
fdf55221e6 | ||
|
|
07f277dd3b | ||
|
|
cf8f0603ca | ||
|
|
5592408ab8 | ||
|
|
a01617b45c | ||
|
|
7abb4087b3 | ||
|
|
dff15cf27a | ||
|
|
aa858137e5 | ||
|
|
45cb143202 | ||
|
|
7a9c6ab8c4 | ||
|
|
e2c26c292d | ||
|
|
be7c3fd00e | ||
|
|
7e5461a2cf | ||
|
|
6ee9010645 | ||
|
|
a23d5be056 | ||
|
|
97a6a1fdc2 | ||
|
|
c8f567347b | ||
|
|
74c1e7f69e | ||
|
|
15a5fc0cae | ||
|
|
f07c54d47c | ||
|
|
70446be108 | ||
|
|
d6d21fca56 | ||
|
|
8d7273924f | ||
|
|
ea64afbaa7 | ||
|
|
45da9837ec | ||
|
|
8c19b7d163 | ||
|
|
ab227a08d0 | ||
|
|
40d6e77964 | ||
|
|
9326e3f1b0 | ||
|
|
0e1eb3daf6 | ||
|
|
05daac12ed | ||
|
|
c5b24b4764 | ||
|
|
cc16548e5f | ||
|
|
0ae61d5865 | ||
|
|
d3bd775a79 | ||
|
|
d921b0f6bd | ||
|
|
0607b95df6 | ||
|
|
98427345cf | ||
|
|
95563c8659 | ||
|
|
d916fda04c | ||
|
|
afa1aa5d93 | ||
|
|
7c1e8ce48c |
@@ -20,4 +20,5 @@ dashboard/
|
||||
data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
.ruff_cache/
|
||||
.astrbot
|
||||
15
.github/FUNDING.yml
vendored
Normal file
15
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: astrbot
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: ['https://afdian.com/a/astrbot_team']
|
||||
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,5 +1,5 @@
|
||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||
修复了 #XYZ
|
||||
解决了 #XYZ
|
||||
|
||||
### Motivation
|
||||
|
||||
@@ -10,5 +10,10 @@
|
||||
<!--简单解释你的改动-->
|
||||
|
||||
### Check
|
||||
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 我新增/修复/优化的功能经过良好的测试
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||
|
||||
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 👀 我的更改经过良好的测试
|
||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||
- [ ] 😮 我的更改没有引入恶意代码
|
||||
|
||||
63
.github/workflows/auto_release.yml
vendored
63
.github/workflows/auto_release.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
name: Auto Release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
build-and-publish-to-github-release:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -23,13 +23,70 @@ jobs:
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo ${{ github.ref_name }} > dist/assets/version
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Upload to Cloudflare R2
|
||||
env:
|
||||
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
|
||||
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
|
||||
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
|
||||
R2_BUCKET_NAME: "astrbot"
|
||||
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
|
||||
VERSION_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
echo "Installing rclone..."
|
||||
curl https://rclone.org/install.sh | sudo bash
|
||||
|
||||
echo "Configuring rclone remote..."
|
||||
mkdir -p ~/.config/rclone
|
||||
cat <<EOF > ~/.config/rclone/rclone.conf
|
||||
[r2]
|
||||
type = s3
|
||||
provider = Cloudflare
|
||||
access_key_id = $R2_ACCESS_KEY_ID
|
||||
secret_access_key = $R2_SECRET_ACCESS_KEY
|
||||
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
|
||||
EOF
|
||||
|
||||
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
|
||||
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
|
||||
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
|
||||
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
|
||||
|
||||
- name: Fetch Changelog
|
||||
run: |
|
||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Create Release
|
||||
- name: Create GitHub Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
bodyFile: ${{ env.changelog }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
artifacts: "dashboard/dist.zip"
|
||||
|
||||
build-and-publish-to-pypi:
|
||||
# 构建并发布到 PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install uv
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
uv build
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
uv publish
|
||||
|
||||
35
.github/workflows/docker-image.yml
vendored
35
.github/workflows/docker-image.yml
vendored
@@ -11,24 +11,42 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: 拉取源码
|
||||
- name: Pull The Codes
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
|
||||
- name: 设置 QEMU
|
||||
- name: Get latest tag (only on manual trigger)
|
||||
id: get-latest-tag
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: |
|
||||
tag=$(git describe --tags --abbrev=0)
|
||||
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout to latest tag (only on manual trigger)
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||
|
||||
- name: Set QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
- name: Set Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: 登录到 DockerHub
|
||||
- name: Log in to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: 构建和推送 Docker hub
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: Soulter
|
||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and Push Docker to DockerHub and Github GHCR
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
@@ -36,8 +54,9 @@ jobs:
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
ghcr.io/soulter/astrbot:latest
|
||||
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -30,3 +30,4 @@ packages/python_interpreter/workplace
|
||||
.conda/
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.10
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.10-slim
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
65
README.md
65
README.md
@@ -19,7 +19,6 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||

|
||||

|
||||
|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
@@ -28,14 +27,25 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
[](https://gitcode.com/Soulter/AstrBot)
|
||||
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
>
|
||||
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
<details><summary>1. AstrBot 现已自带知识库能力</summary>
|
||||
|
||||
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
@@ -76,14 +86,29 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
#### 手动部署
|
||||
|
||||
推荐使用 `uv`。
|
||||
> 推荐使用 `uv`。
|
||||
|
||||
首先,安装 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
通过 Git Clone 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
pip install uv
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者,直接通过 uvx 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
mkdir astrbot && cd astrbot
|
||||
uvx astrbot init
|
||||
# uvx astrbot run
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
@@ -96,9 +121,10 @@ uv run main.py
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| Telegram | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 企业微信 | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 微信客服 | ✔ | 私聊 | 文字、图片 |
|
||||
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
@@ -110,21 +136,26 @@ uv run main.py
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| DashScope(阿里云百炼应用) | ✔ | LLMOps | |
|
||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -148,10 +179,11 @@ pre-commit install
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
@@ -173,6 +205,9 @@ _✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
@@ -181,6 +216,10 @@ _✨ WebUI ✨_
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本项目的诞生离不开以下开源项目:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
|
||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
|
||||
1
astrbot/cli/__init__.py
Normal file
1
astrbot/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "3.5.8"
|
||||
59
astrbot/cli/__main__.py
Normal file
59
astrbot/cli/__main__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
AstrBot CLI入口
|
||||
"""
|
||||
|
||||
import click
|
||||
import sys
|
||||
from . import __version__
|
||||
from .commands import init, run, plug, conf
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(__version__, prog_name="AstrBot")
|
||||
def cli() -> None:
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(logo_tmpl)
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot CLI version: {__version__}")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("command_name", required=False, type=str)
|
||||
def help(command_name: str | None) -> None:
|
||||
"""显示命令的帮助信息
|
||||
|
||||
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
|
||||
否则,显示通用帮助信息。
|
||||
"""
|
||||
ctx = click.get_current_context()
|
||||
if command_name:
|
||||
# 查找指定命令
|
||||
command = cli.get_command(ctx, command_name)
|
||||
if command:
|
||||
# 显示特定命令的帮助信息
|
||||
click.echo(command.get_help(ctx))
|
||||
else:
|
||||
click.echo(f"Unknown command: {command_name}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 显示通用帮助信息
|
||||
click.echo(cli.get_help(ctx))
|
||||
|
||||
|
||||
cli.add_command(init)
|
||||
cli.add_command(run)
|
||||
cli.add_command(help)
|
||||
cli.add_command(plug)
|
||||
cli.add_command(conf)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
6
astrbot/cli/commands/__init__.py
Normal file
6
astrbot/cli/commands/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .cmd_init import init
|
||||
from .cmd_run import run
|
||||
from .cmd_plug import plug
|
||||
from .cmd_conf import conf
|
||||
|
||||
__all__ = ["init", "run", "plug", "conf"]
|
||||
206
astrbot/cli/commands/cmd_conf.py
Normal file
206
astrbot/cli/commands/cmd_conf.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
import click
|
||||
import hashlib
|
||||
import zoneinfo
|
||||
from typing import Any, Callable
|
||||
from ..utils import get_astrbot_root, check_astrbot_root
|
||||
|
||||
|
||||
def _validate_log_level(value: str) -> str:
|
||||
"""验证日志级别"""
|
||||
value = value.upper()
|
||||
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise click.ClickException(
|
||||
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_port(value: str) -> int:
|
||||
"""验证 Dashboard 端口"""
|
||||
try:
|
||||
port = int(value)
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException("端口必须在 1-65535 范围内")
|
||||
return port
|
||||
except ValueError:
|
||||
raise click.ClickException("端口必须是数字")
|
||||
|
||||
|
||||
def _validate_dashboard_username(value: str) -> str:
|
||||
"""验证 Dashboard 用户名"""
|
||||
if not value:
|
||||
raise click.ClickException("用户名不能为空")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_password(value: str) -> str:
|
||||
"""验证 Dashboard 密码"""
|
||||
if not value:
|
||||
raise click.ClickException("密码不能为空")
|
||||
return hashlib.md5(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def _validate_timezone(value: str) -> str:
|
||||
"""验证时区"""
|
||||
try:
|
||||
zoneinfo.ZoneInfo(value)
|
||||
except Exception:
|
||||
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_callback_api_base(value: str) -> str:
|
||||
"""验证回调接口基址"""
|
||||
if not value.startswith("http://") and not value.startswith("https://"):
|
||||
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
|
||||
return value
|
||||
|
||||
|
||||
# 可通过CLI设置的配置项,配置键到验证器函数的映射
|
||||
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
"timezone": _validate_timezone,
|
||||
"log_level": _validate_log_level,
|
||||
"dashboard.port": _validate_dashboard_port,
|
||||
"dashboard.username": _validate_dashboard_username,
|
||||
"dashboard.password": _validate_dashboard_password,
|
||||
"callback_api_base": _validate_callback_api_base,
|
||||
}
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
"""加载或初始化配置文件"""
|
||||
root = get_astrbot_root()
|
||||
if not check_astrbot_root(root):
|
||||
raise click.ClickException(
|
||||
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
config_path = root / "data" / "cmd_config.json"
|
||||
if not config_path.exists():
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8-sig"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.ClickException(f"配置文件解析失败: {str(e)}")
|
||||
|
||||
|
||||
def _save_config(config: dict[str, Any]) -> None:
|
||||
"""保存配置文件"""
|
||||
config_path = get_astrbot_root() / "data" / "cmd_config.json"
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
|
||||
)
|
||||
|
||||
|
||||
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||
"""设置嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts[:-1]:
|
||||
if part not in obj:
|
||||
obj[part] = {}
|
||||
elif not isinstance(obj[part], dict):
|
||||
raise click.ClickException(
|
||||
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
|
||||
)
|
||||
obj = obj[part]
|
||||
obj[parts[-1]] = value
|
||||
|
||||
|
||||
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
"""获取嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj
|
||||
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf():
|
||||
"""配置管理命令
|
||||
|
||||
支持的配置项:
|
||||
|
||||
- timezone: 时区设置 (例如: Asia/Shanghai)
|
||||
|
||||
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
|
||||
- dashboard.port: Dashboard 端口
|
||||
|
||||
- dashboard.username: Dashboard 用户名
|
||||
|
||||
- dashboard.password: Dashboard 密码
|
||||
|
||||
- callback_api_base: 回调接口基址
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str):
|
||||
"""设置配置项的值"""
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
config = _load_config()
|
||||
|
||||
try:
|
||||
old_value = _get_nested_item(config, key)
|
||||
validated_value = CONFIG_VALIDATORS[key](value)
|
||||
_set_nested_item(config, key, validated_value)
|
||||
_save_config(config)
|
||||
|
||||
click.echo(f"配置已更新: {key}")
|
||||
if key == "dashboard.password":
|
||||
click.echo(" 原值: ********")
|
||||
click.echo(" 新值: ********")
|
||||
else:
|
||||
click.echo(f" 原值: {old_value}")
|
||||
click.echo(f" 新值: {validated_value}")
|
||||
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"设置配置失败: {str(e)}")
|
||||
|
||||
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str = None):
|
||||
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||
config = _load_config()
|
||||
|
||||
if key:
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
try:
|
||||
value = _get_nested_item(config, key)
|
||||
if key == "dashboard.password":
|
||||
value = "********"
|
||||
click.echo(f"{key}: {value}")
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"获取配置失败: {str(e)}")
|
||||
else:
|
||||
click.echo("当前配置:")
|
||||
for key in CONFIG_VALIDATORS.keys():
|
||||
try:
|
||||
value = (
|
||||
"********"
|
||||
if key == "dashboard.password"
|
||||
else _get_nested_item(config, key)
|
||||
)
|
||||
click.echo(f" {key}: {value}")
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
55
astrbot/cli/commands/cmd_init.py
Normal file
55
astrbot/cli/commands/cmd_init.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def initialize_astrbot(astrbot_root) -> None:
|
||||
"""执行 AstrBot 初始化逻辑"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(f"Current Directory: {astrbot_root}")
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||
)
|
||||
if click.confirm(
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
|
||||
paths = {
|
||||
"data": astrbot_root / "data",
|
||||
"config": astrbot_root / "data" / "config",
|
||||
"plugins": astrbot_root / "data" / "plugins",
|
||||
"temp": astrbot_root / "data" / "temp",
|
||||
}
|
||||
|
||||
for name, path in paths.items():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
|
||||
@click.command()
|
||||
def init() -> None:
|
||||
"""初始化 AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = get_astrbot_root()
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"初始化失败: {e!s}")
|
||||
247
astrbot/cli/commands/cmd_plug.py
Normal file
247
astrbot/cli/commands/cmd_plug.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import shutil
|
||||
|
||||
|
||||
from ..utils import (
|
||||
get_git_repo,
|
||||
build_plug_list,
|
||||
manage_plugin,
|
||||
PluginStatus,
|
||||
check_astrbot_root,
|
||||
get_astrbot_root,
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
def plug():
|
||||
"""插件管理"""
|
||||
pass
|
||||
|
||||
|
||||
def _get_data_path() -> Path:
|
||||
base = get_astrbot_root()
|
||||
if not check_astrbot_root(base):
|
||||
raise click.ClickException(
|
||||
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
return (base / "data").resolve()
|
||||
|
||||
|
||||
def display_plugins(plugins, title=None, color=None):
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
|
||||
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
|
||||
click.echo("-" * 85)
|
||||
|
||||
for p in plugins:
|
||||
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
|
||||
click.echo(
|
||||
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
|
||||
f"{p['author']:<15} {desc:<30}"
|
||||
)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def new(name: str):
|
||||
"""创建新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(f"插件 {name} 已存在")
|
||||
|
||||
author = click.prompt("请输入插件作者", type=str)
|
||||
desc = click.prompt("请输入插件描述", type=str)
|
||||
version = click.prompt("请输入插件版本", type=str)
|
||||
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
|
||||
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
|
||||
repo = click.prompt("请输入插件仓库:", type=str)
|
||||
if not repo.startswith("http"):
|
||||
raise click.ClickException("仓库地址必须以 http 开头")
|
||||
|
||||
click.echo("下载插件模板...")
|
||||
get_git_repo(
|
||||
"https://github.com/Soulter/helloworld",
|
||||
plug_path,
|
||||
)
|
||||
|
||||
click.echo("重写插件信息...")
|
||||
# 重写 metadata.yaml
|
||||
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"name: {name}\n"
|
||||
f"desc: {desc}\n"
|
||||
f"version: {version}\n"
|
||||
f"author: {author}\n"
|
||||
f"repo: {repo}\n"
|
||||
)
|
||||
|
||||
# 重写 README.md
|
||||
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
|
||||
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
|
||||
|
||||
# 重写 main.py
|
||||
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
new_content = content.replace(
|
||||
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
|
||||
f'@register("{name}", "{author}", "{desc}", "{version}")',
|
||||
)
|
||||
|
||||
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
click.echo(f"插件 {name} 创建成功")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||
def list(all: bool):
|
||||
"""列出插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
# 未发布的插件
|
||||
not_published_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
|
||||
]
|
||||
if not_published_plugins:
|
||||
display_plugins(not_published_plugins, "未发布的插件", "red")
|
||||
|
||||
# 需要更新的插件
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
if need_update_plugins:
|
||||
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
|
||||
|
||||
# 已安装的插件
|
||||
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
|
||||
if installed_plugins:
|
||||
display_plugins(installed_plugins, "已安装的插件", "green")
|
||||
|
||||
# 未安装的插件
|
||||
not_installed_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
|
||||
]
|
||||
if not_installed_plugins and all:
|
||||
display_plugins(not_installed_plugins, "未安装的插件", "blue")
|
||||
|
||||
if (
|
||||
not any([not_published_plugins, need_update_plugins, installed_plugins])
|
||||
and not all
|
||||
):
|
||||
click.echo("未安装任何插件")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
@click.option("--proxy", help="代理服务器地址")
|
||||
def install(name: str, proxy: str | None):
|
||||
"""安装插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def remove(name: str):
|
||||
"""卸载插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||
|
||||
if not plugin or not plugin.get("local_path"):
|
||||
raise click.ClickException(f"插件 {name} 不存在或未安装")
|
||||
|
||||
plugin_path = plugin["local_path"]
|
||||
|
||||
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
|
||||
|
||||
try:
|
||||
shutil.rmtree(plugin_path)
|
||||
click.echo(f"插件 {name} 已卸载")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--proxy", help="Github代理地址")
|
||||
def update(name: str, proxy: str | None):
|
||||
"""更新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
if name:
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
else:
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
|
||||
if not need_update_plugins:
|
||||
click.echo("没有需要更新的插件")
|
||||
return
|
||||
|
||||
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
|
||||
for plugin in need_update_plugins:
|
||||
plugin_name = plugin["name"]
|
||||
click.echo(f"正在更新插件 {plugin_name}...")
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("query")
|
||||
def search(query: str):
|
||||
"""搜索插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
matched_plugins = [
|
||||
p
|
||||
for p in plugins
|
||||
if query.lower() in p["name"].lower()
|
||||
or query.lower() in p["desc"].lower()
|
||||
or query.lower() in p["author"].lower()
|
||||
]
|
||||
|
||||
if not matched_plugins:
|
||||
click.echo(f"未找到匹配 '{query}' 的插件")
|
||||
return
|
||||
|
||||
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
|
||||
63
astrbot/cli/commands/cmd_run.py
Normal file
63
astrbot/cli/commands/cmd_run.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import logger, LogManager, LogBroker, db_helper
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
db = db_helper
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
|
||||
await core_lifecycle.start()
|
||||
|
||||
|
||||
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
|
||||
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
|
||||
@click.command()
|
||||
def run(reload: bool, port: str) -> None:
|
||||
"""运行 AstrBot"""
|
||||
try:
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
astrbot_root = get_astrbot_root()
|
||||
|
||||
if not check_astrbot_root(astrbot_root):
|
||||
raise click.ClickException(
|
||||
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||
sys.path.insert(0, str(astrbot_root))
|
||||
|
||||
if port:
|
||||
os.environ["DASHBOARD_PORT"] = port
|
||||
|
||||
if reload:
|
||||
click.echo("启用插件自动重载")
|
||||
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
|
||||
18
astrbot/cli/utils/__init__.py
Normal file
18
astrbot/cli/utils/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from .basic import (
|
||||
get_astrbot_root,
|
||||
check_astrbot_root,
|
||||
check_dashboard,
|
||||
)
|
||||
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
__all__ = [
|
||||
"get_astrbot_root",
|
||||
"check_astrbot_root",
|
||||
"check_dashboard",
|
||||
"get_git_repo",
|
||||
"manage_plugin",
|
||||
"build_plug_list",
|
||||
"VersionComparator",
|
||||
"PluginStatus",
|
||||
]
|
||||
67
astrbot/cli/utils/basic.py
Normal file
67
astrbot/cli/utils/basic.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""检查路径是否为 AstrBot 根目录"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
if not (path / ".astrbot").exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""获取Astrbot根目录路径"""
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
"""检查是否安装了dashboard"""
|
||||
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||
from astrbot.core.config.default import VERSION
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||
click.echo("管理面板已是最新版本")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("初始化管理面板目录...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
230
astrbot/cli/utils/plugin.py
Normal file
230
astrbot/cli/utils/plugin.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import click
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class PluginStatus(str, Enum):
|
||||
INSTALLED = "已安装"
|
||||
NEED_UPDATE = "需更新"
|
||||
NOT_INSTALLED = "未安装"
|
||||
NOT_PUBLISHED = "未发布"
|
||||
|
||||
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
# 解析仓库信息
|
||||
repo_namespace = url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
# 尝试获取最新的 release
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
try:
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(release_url)
|
||||
resp.raise_for_status()
|
||||
releases = resp.json()
|
||||
|
||||
if releases:
|
||||
# 使用最新的 release
|
||||
download_url = releases[0]["zipball_url"]
|
||||
else:
|
||||
# 没有 release,使用默认分支
|
||||
click.echo(f"正在从默认分支下载 {author}/{repo}")
|
||||
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
except Exception as e:
|
||||
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
|
||||
download_url = url
|
||||
|
||||
# 应用代理
|
||||
if proxy:
|
||||
download_url = f"{proxy}/{download_url}"
|
||||
|
||||
# 下载并解压
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(download_url)
|
||||
if (
|
||||
resp.status_code == 404
|
||||
and "archive/refs/heads/master.zip" in download_url
|
||||
):
|
||||
alt_url = download_url.replace("master.zip", "main.zip")
|
||||
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||
resp = client.get(alt_url)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
zip_content = BytesIO(resp.content)
|
||||
with ZipFile(zip_content) as z:
|
||||
z.extractall(temp_dir)
|
||||
namelist = z.namelist()
|
||||
root_dir = Path(namelist[0]).parts[0] if namelist else ""
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.move(temp_dir / root_dir, target_path)
|
||||
finally:
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
"""从 metadata.yaml 文件加载插件元数据
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含元数据的字典,如果读取失败则返回空字典
|
||||
"""
|
||||
yaml_path = plugin_dir / "metadata.yaml"
|
||||
if yaml_path.exists():
|
||||
try:
|
||||
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
|
||||
except Exception as e:
|
||||
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
|
||||
return {}
|
||||
|
||||
|
||||
def build_plug_list(plugins_dir: Path) -> list:
|
||||
"""构建插件列表,包含本地和在线插件信息
|
||||
|
||||
Args:
|
||||
plugins_dir (Path): 插件目录路径
|
||||
|
||||
Returns:
|
||||
list: 包含插件信息的字典列表
|
||||
"""
|
||||
# 获取本地插件信息
|
||||
result = []
|
||||
if plugins_dir.exists():
|
||||
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||
plugin_dir = plugins_dir / plugin_name
|
||||
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
try:
|
||||
with httpx.Client() as client:
|
||||
resp = client.get("https://api.soulter.top/astrbot/plugins")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
# 与在线插件比对,更新状态
|
||||
online_plugin_names = {plugin["name"] for plugin in online_plugins}
|
||||
for local_plugin in result:
|
||||
if local_plugin["name"] in online_plugin_names:
|
||||
# 查找对应的在线插件
|
||||
online_plugin = next(
|
||||
p for p in online_plugins if p["name"] == local_plugin["name"]
|
||||
)
|
||||
if (
|
||||
VersionComparator.compare_version(
|
||||
local_plugin["version"], online_plugin["version"]
|
||||
)
|
||||
< 0
|
||||
):
|
||||
local_plugin["status"] = PluginStatus.NEED_UPDATE
|
||||
else:
|
||||
# 本地插件未在线上发布
|
||||
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
|
||||
|
||||
# 添加未安装的在线插件
|
||||
for online_plugin in online_plugins:
|
||||
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
|
||||
result.append(online_plugin)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def manage_plugin(
|
||||
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
|
||||
) -> None:
|
||||
"""安装或更新插件
|
||||
|
||||
Args:
|
||||
plugin (dict): 插件信息字典
|
||||
plugins_dir (Path): 插件目录
|
||||
is_update (bool, optional): 是否为更新操作. 默认为 False
|
||||
proxy (str, optional): 代理服务器地址
|
||||
"""
|
||||
plugin_name = plugin["name"]
|
||||
repo_url = plugin["repo"]
|
||||
|
||||
# 如果是更新且有本地路径,直接使用本地路径
|
||||
if is_update and plugin.get("local_path"):
|
||||
target_path = Path(plugin["local_path"])
|
||||
else:
|
||||
target_path = plugins_dir / plugin_name
|
||||
|
||||
backup_path = Path(f"{target_path}_backup") if is_update else None
|
||||
|
||||
# 检查插件是否存在
|
||||
if is_update and not target_path.exists():
|
||||
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
||||
|
||||
# 备份现有插件
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
if is_update:
|
||||
shutil.copytree(target_path, backup_path)
|
||||
|
||||
try:
|
||||
click.echo(
|
||||
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
|
||||
)
|
||||
get_git_repo(repo_url, target_path, proxy)
|
||||
|
||||
# 更新成功,删除备份
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
||||
except Exception as e:
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
if is_update and backup_path.exists():
|
||||
shutil.move(backup_path, target_path)
|
||||
raise click.ClickException(
|
||||
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
||||
)
|
||||
92
astrbot/cli/utils/version_comparator.py
Normal file
92
astrbot/cli/utils/version_comparator.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
拷贝自 astrbot.core.utils.version_comparator
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
@staticmethod
|
||||
def compare_version(v1: str, v2: str) -> int:
|
||||
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||
|
||||
参考: https://semver.org/lang/zh-CN/
|
||||
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.lower().replace("v", "")
|
||||
v2 = v2.lower().replace("v", "")
|
||||
|
||||
def split_version(version):
|
||||
match = re.match(
|
||||
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
|
||||
version,
|
||||
)
|
||||
if not match:
|
||||
return [], None
|
||||
major_minor_patch = match.group(1).split(".")
|
||||
prerelease = match.group(2)
|
||||
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||
parts = [int(x) for x in major_minor_patch]
|
||||
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||
return parts, prerelease
|
||||
|
||||
v1_parts, v1_prerelease = split_version(v1)
|
||||
v2_parts, v2_prerelease = split_version(v2)
|
||||
|
||||
# 比较数字部分
|
||||
length = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||
|
||||
for i in range(length):
|
||||
if v1_parts[i] > v2_parts[i]:
|
||||
return 1
|
||||
elif v1_parts[i] < v2_parts[i]:
|
||||
return -1
|
||||
|
||||
# 比较预发布标签
|
||||
if v1_prerelease is None and v2_prerelease is not None:
|
||||
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is None:
|
||||
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is not None:
|
||||
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||
for i in range(len_pre):
|
||||
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
||||
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
|
||||
|
||||
if p1 is None and p2 is not None:
|
||||
return -1
|
||||
elif p1 is not None and p2 is None:
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, str):
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
return 0 # 预发布标签完全相同
|
||||
|
||||
return 0 # 数字部分和预发布标签都相同
|
||||
|
||||
@staticmethod
|
||||
def _split_prerelease(prerelease):
|
||||
if not prerelease:
|
||||
return None
|
||||
parts = prerelease.split(".")
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
result.append(int(part))
|
||||
else:
|
||||
result.append(part)
|
||||
return result
|
||||
@@ -7,27 +7,28 @@ from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.file_token_service import FileTokenService
|
||||
from .utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs("data", exist_ok=True)
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = (
|
||||
SharedPreferences()
|
||||
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences()
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
astrbot_config.get("pip_install_arg", ""),
|
||||
astrbot_config.get("pypi_index_url", None),
|
||||
)
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -42,11 +43,10 @@ class AstrBotConfig(dict):
|
||||
"""不存在时载入默认配置"""
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
|
||||
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith("/ufeff"): # remove BOM
|
||||
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
@@ -83,23 +83,61 @@ class AstrBotConfig(dict):
|
||||
return conf
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
"""检查配置完整性,如果有新的配置项则返回 True"""
|
||||
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
|
||||
has_new = False
|
||||
|
||||
# 创建一个新的有序字典以保持参考配置的顺序
|
||||
new_conf = {}
|
||||
|
||||
# 先按照参考配置的顺序添加配置项
|
||||
for key, value in refer_conf.items():
|
||||
if key not in conf:
|
||||
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}")
|
||||
# 配置项不存在,插入默认值
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
|
||||
conf[key] = value
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
if conf[key] is None:
|
||||
conf[key] = value
|
||||
# 配置项为 None,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
elif isinstance(value, dict):
|
||||
has_new |= self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
# 递归检查子配置项
|
||||
if not isinstance(conf[key], dict):
|
||||
# 类型不匹配,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
# 递归检查并同步顺序
|
||||
child_has_new = self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
new_conf[key] = conf[key]
|
||||
has_new |= child_has_new
|
||||
else:
|
||||
# 直接使用现有配置
|
||||
new_conf[key] = conf[key]
|
||||
|
||||
# 检查是否存在参考配置中没有的配置项
|
||||
for key in list(conf.keys()):
|
||||
if key not in refer_conf:
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
|
||||
has_new = True
|
||||
|
||||
# 顺序不一致也算作变更
|
||||
if list(conf.keys()) != list(new_conf.keys()):
|
||||
if path:
|
||||
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
|
||||
else:
|
||||
logger.info("检查到配置项顺序不一致,已重新排序")
|
||||
has_new = True
|
||||
|
||||
# 更新原始配置
|
||||
conf.clear()
|
||||
conf.update(new_conf)
|
||||
|
||||
return has_new
|
||||
|
||||
def save_config(self, replace_config: Dict = None):
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.5.5"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.15"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
@@ -37,12 +40,15 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
"empty_mention_waiting": True,
|
||||
"empty_mention_waiting_need_reply": True,
|
||||
"friend_message_needs_wake_prefix": False,
|
||||
"ignore_bot_self_message": False,
|
||||
"ignore_at_all": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
@@ -54,6 +60,7 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -63,6 +70,7 @@ DEFAULT_CONFIG = {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -88,6 +96,7 @@ DEFAULT_CONFIG = {
|
||||
"t2i_word_threshold": 150,
|
||||
"t2i_strategy": "remote",
|
||||
"t2i_endpoint": "",
|
||||
"t2i_use_file_service": False,
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
@@ -104,6 +113,7 @@ DEFAULT_CONFIG = {
|
||||
"knowledge_db": {},
|
||||
"persona": [],
|
||||
"timezone": "",
|
||||
"callback_api_base": "",
|
||||
}
|
||||
|
||||
|
||||
@@ -140,6 +150,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "0.0.0.0",
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
@@ -150,6 +161,29 @@ CONFIG_METADATA_2 = {
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
"wechatpadpro(微信)": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
"weixin_official_account(微信公众平台)": {
|
||||
"id": "weixin_official_account",
|
||||
"type": "weixin_official_account",
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6194,
|
||||
"active_send_mode": False,
|
||||
},
|
||||
"wecom(企业微信)": {
|
||||
"id": "wecom",
|
||||
"type": "wecom",
|
||||
@@ -158,6 +192,7 @@ CONFIG_METADATA_2 = {
|
||||
"secret": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"kf_name": "",
|
||||
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
@@ -186,19 +221,57 @@ CONFIG_METADATA_2 = {
|
||||
"start_message": "Hello, I'm AstrBot!",
|
||||
"telegram_api_base_url": "https://api.telegram.org/bot",
|
||||
"telegram_file_base_url": "https://api.telegram.org/file/bot",
|
||||
"telegram_command_register": True,
|
||||
"telegram_command_auto_refresh": True,
|
||||
"telegram_command_register_interval": 300,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"active_send_mode": {
|
||||
"description": "是否换用主动发送接口",
|
||||
"type": "bool",
|
||||
"desc": "只有企业认证的公众号才能主动发送。主动发送接口的限制会少一些。",
|
||||
},
|
||||
"wpp_active_message_poll": {
|
||||
"description": "是否启用主动消息轮询",
|
||||
"type": "bool",
|
||||
"hint": "只有当你发现微信消息没有按时同步到 AstrBot 时,才需要启用这个功能,默认不启用。",
|
||||
},
|
||||
"wpp_active_message_poll_interval": {
|
||||
"description": "主动消息轮询间隔",
|
||||
"type": "int",
|
||||
"hint": "主动消息轮询间隔,单位为秒,默认 3 秒,最大不要超过 60 秒,否则可能被认为是旧消息。",
|
||||
},
|
||||
"kf_name": {
|
||||
"description": "微信客服账号名",
|
||||
"type": "string",
|
||||
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取",
|
||||
},
|
||||
"telegram_token": {
|
||||
"description": "Bot Token",
|
||||
"type": "string",
|
||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||
},
|
||||
"telegram_command_register": {
|
||||
"description": "Telegram 命令注册",
|
||||
"type": "bool",
|
||||
"hint": "启用后,AstrBot 将会自动注册 Telegram 命令。",
|
||||
},
|
||||
"telegram_command_auto_refresh": {
|
||||
"description": "Telegram 命令自动刷新",
|
||||
"type": "bool",
|
||||
"hint": "启用后,AstrBot 将会在运行时自动刷新 Telegram 命令。(单独设置此项无效)",
|
||||
},
|
||||
"telegram_command_register_interval": {
|
||||
"description": "Telegram 命令自动刷新间隔",
|
||||
"type": "int",
|
||||
"hint": "Telegram 命令自动刷新间隔,单位为秒。",
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"description": "机器人名称",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
|
||||
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -218,7 +291,7 @@ CONFIG_METADATA_2 = {
|
||||
"secret": {
|
||||
"description": "secret",
|
||||
"type": "string",
|
||||
"hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。",
|
||||
"hint": "必填项。",
|
||||
},
|
||||
"enable_group_c2c": {
|
||||
"description": "启用消息列表单聊",
|
||||
@@ -240,6 +313,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"ws_reverse_token": {
|
||||
"description": "反向 Websocket Token",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
@@ -281,9 +359,14 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
||||
},
|
||||
"empty_mention_waiting": {
|
||||
"description": "只 @ 机器人是否触发等待回复",
|
||||
"description": "只 @ 机器人是否触发等待",
|
||||
"type": "bool",
|
||||
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待回复,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
|
||||
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
|
||||
},
|
||||
"empty_mention_waiting_need_reply": {
|
||||
"description": "只 @ 机器人触发等待时是否需要回复提醒",
|
||||
"type": "bool",
|
||||
"hint": "在上面一个配置项中,如果启用了触发等待,启用此项后,机器人会使用 LLM 生成一条回复。否则,将不回复而只是等待。",
|
||||
},
|
||||
"friend_message_needs_wake_prefix": {
|
||||
"description": "私聊消息是否需要唤醒前缀",
|
||||
@@ -295,6 +378,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"ignore_at_all": {
|
||||
"description": "是否忽略 @ 全体成员",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人会忽略 @ 全体成员 的消息事件。",
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
@@ -451,6 +539,7 @@ CONFIG_METADATA_2 = {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
@@ -459,9 +548,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"Azure_OpenAI": {
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
@@ -471,9 +561,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"xAI(grok)": {
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
@@ -482,9 +573,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "grok-2-latest",
|
||||
},
|
||||
},
|
||||
"Anthropic(claude)": {
|
||||
"Anthropic": {
|
||||
"id": "claude",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
@@ -497,6 +589,7 @@ CONFIG_METADATA_2 = {
|
||||
"Ollama": {
|
||||
"id": "ollama_default",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
@@ -504,9 +597,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
},
|
||||
"LM_Studio": {
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://localhost:1234/v1",
|
||||
@@ -517,6 +611,7 @@ CONFIG_METADATA_2 = {
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
@@ -525,9 +620,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
},
|
||||
"Gemini(googlegenai原生)": {
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
@@ -538,16 +634,21 @@ CONFIG_METADATA_2 = {
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
"gm_url_context": False,
|
||||
"gm_safety_settings": {
|
||||
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
@@ -556,9 +657,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-chat",
|
||||
},
|
||||
},
|
||||
"Zhipu(智谱)": {
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
@@ -567,9 +669,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
},
|
||||
"SiliconFlow(硅基流动)": {
|
||||
"硅基流动": {
|
||||
"id": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
@@ -578,9 +681,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
},
|
||||
},
|
||||
"MoonShot(Kimi)": {
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
@@ -589,9 +693,22 @@ CONFIG_METADATA_2 = {
|
||||
"model": "moonshot-v1-8k",
|
||||
},
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"id": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
},
|
||||
},
|
||||
"LLMTuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"base_model_path": "",
|
||||
"adapter_model_path": "",
|
||||
@@ -602,6 +719,7 @@ CONFIG_METADATA_2 = {
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
@@ -611,9 +729,10 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"Dashscope(阿里云百炼应用)": {
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"type": "dashscope",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"dashscope_app_type": "agent",
|
||||
"dashscope_api_key": "",
|
||||
@@ -629,6 +748,7 @@ CONFIG_METADATA_2 = {
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
@@ -637,6 +757,7 @@ CONFIG_METADATA_2 = {
|
||||
"Whisper(API)": {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_api",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
@@ -644,22 +765,25 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"sensevoice(本地加载)": {
|
||||
"SenseVoice(本地加载)": {
|
||||
"sensevoice_hint": "(不用修改我)",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "sensevoice",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"stt_model": "iic/SenseVoiceSmall",
|
||||
"is_emotion": False,
|
||||
},
|
||||
"OpenAI_TTS(API)": {
|
||||
"OpenAI TTS(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
@@ -667,43 +791,232 @@ CONFIG_METADATA_2 = {
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge_TTS": {
|
||||
"Edge TTS": {
|
||||
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"type": "edge_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
|
||||
"timeout": 20,
|
||||
},
|
||||
"GSVI_TTS(API)": {
|
||||
"GSVI TTS(API)": {
|
||||
"id": "gsvi_tts",
|
||||
"type": "gsvi_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:5000",
|
||||
"character": "",
|
||||
"emotion": "default",
|
||||
"enable": False,
|
||||
"timeout": 20,
|
||||
},
|
||||
"FishAudio_TTS(API)": {
|
||||
"FishAudio TTS(API)": {
|
||||
"id": "fishaudio_tts",
|
||||
"type": "fishaudio_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "https://api.fish.audio/v1",
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"timeout": "20",
|
||||
},
|
||||
"阿里云百炼_TTS(API)": {
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"id": "dashscope_tts",
|
||||
"type": "dashscope_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"model": "cosyvoice-v1",
|
||||
"dashscope_tts_voice": "loongstella",
|
||||
"timeout": "20",
|
||||
},
|
||||
"Azure TTS": {
|
||||
"id": "azure_tts",
|
||||
"type": "azure_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": True,
|
||||
"azure_tts_voice": "zh-CN-YunxiaNeural",
|
||||
"azure_tts_style": "cheerful",
|
||||
"azure_tts_role": "Boy",
|
||||
"azure_tts_rate": "1",
|
||||
"azure_tts_volume": "100",
|
||||
"azure_tts_subscription_key": "",
|
||||
"azure_tts_region": "eastus",
|
||||
},
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
"type": "minimax_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "https://api.minimax.chat/v1/t2a_v2",
|
||||
"minimax-group-id": "",
|
||||
"model": "speech-02-turbo",
|
||||
"minimax-langboost": "auto",
|
||||
"minimax-voice-speed": 1.0,
|
||||
"minimax-voice-vol": 1.0,
|
||||
"minimax-voice-pitch": 0,
|
||||
"minimax-is-timber-weight": False,
|
||||
"minimax-voice-id": "female-shaonv",
|
||||
"minimax-timber-weight": '[\n {\n "voice_id": "Chinese (Mandarin)_Warm_Girl",\n "weight": 25\n },\n {\n "voice_id": "Chinese (Mandarin)_BashfulGirl",\n "weight": 50\n }\n]',
|
||||
"minimax-voice-emotion": "neutral",
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
},
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
"type": "volcengine_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"appid": "",
|
||||
"volcengine_cluster": "volcano_tts",
|
||||
"volcengine_voice_type": "",
|
||||
"volcengine_speed_ratio": 1.0,
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1536,
|
||||
"timeout": 20,
|
||||
},
|
||||
"Gemini Embedding": {
|
||||
"id": "gemini_embedding",
|
||||
"type": "gemini_embedding",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "gemini-embedding-exp-03-07",
|
||||
"embedding_dimensions": 768,
|
||||
"timeout": 20,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"embedding_dimensions": {
|
||||
"description": "嵌入维度",
|
||||
"type": "int",
|
||||
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
|
||||
},
|
||||
"embedding_model": {
|
||||
"description": "嵌入模型",
|
||||
"type": "string",
|
||||
"hint": "嵌入模型名称。",
|
||||
},
|
||||
"embedding_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
},
|
||||
"embedding_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
},
|
||||
"volcengine_cluster": {
|
||||
"type": "string",
|
||||
"description": "火山引擎集群",
|
||||
"hint": "若使用语音复刻大模型,可选volcano_icl或volcano_icl_concurr,默认使用volcano_tts",
|
||||
},
|
||||
"volcengine_voice_type": {
|
||||
"type": "string",
|
||||
"description": "火山引擎音色",
|
||||
"hint": "输入声音id(Voice_type)",
|
||||
},
|
||||
"volcengine_speed_ratio": {
|
||||
"type": "float",
|
||||
"description": "语速设置",
|
||||
"hint": "语速设置,范围为 0.2 到 3.0,默认值为 1.0",
|
||||
},
|
||||
"volcengine_volume_ratio": {
|
||||
"type": "float",
|
||||
"description": "音量设置",
|
||||
"hint": "音量设置,范围为 0.0 到 2.0,默认值为 1.0",
|
||||
},
|
||||
"azure_tts_voice": {
|
||||
"type": "string",
|
||||
"description": "音色设置",
|
||||
"hint": "API 音色",
|
||||
},
|
||||
"azure_tts_style": {
|
||||
"type": "string",
|
||||
"description": "风格设置",
|
||||
"hint": "声音特定的讲话风格。 可以表达快乐、同情和平静等情绪。",
|
||||
},
|
||||
"azure_tts_role": {
|
||||
"type": "string",
|
||||
"description": "模仿设置(可选)",
|
||||
"hint": "讲话角色扮演。 声音可以模仿不同的年龄和性别,但声音名称不会更改。 例如,男性语音可以提高音调和改变语调来模拟女性语音,但语音名称不会更改。 如果角色缺失或不受声音的支持,则会忽略此属性。",
|
||||
"options": [
|
||||
"Boy",
|
||||
"Girl",
|
||||
"YoungAdultFemale",
|
||||
"YoungAdultMale",
|
||||
"OlderAdultFemale",
|
||||
"OlderAdultMale",
|
||||
"SeniorFemale",
|
||||
"SeniorMale",
|
||||
"禁用",
|
||||
],
|
||||
},
|
||||
"azure_tts_rate": {
|
||||
"type": "string",
|
||||
"description": "语速设置",
|
||||
"hint": "指示文本的讲出速率。可在字词或句子层面应用语速。 速率变化应为原始音频的 0.5 到 2 倍。",
|
||||
},
|
||||
"azure_tts_volume": {
|
||||
"type": "string",
|
||||
"description": "语音音量设置",
|
||||
"hint": "指示语音的音量级别。 可在句子层面应用音量的变化。以从 0.0 到 100.0(从最安静到最大声,例如 75)的数字表示。 默认值为 100.0。",
|
||||
},
|
||||
"azure_tts_region": {
|
||||
"type": "string",
|
||||
"description": "API 地区",
|
||||
"hint": "Azure_TTS 处理数据所在区域,具体参考 https://learn.microsoft.com/zh-cn/azure/ai-services/speech-service/regions",
|
||||
"options": [
|
||||
"southafricanorth",
|
||||
"eastasia",
|
||||
"southeastasia",
|
||||
"australiaeast",
|
||||
"centralindia",
|
||||
"japaneast",
|
||||
"japanwest",
|
||||
"koreacentral",
|
||||
"canadacentral",
|
||||
"northeurope",
|
||||
"westeurope",
|
||||
"francecentral",
|
||||
"germanywestcentral",
|
||||
"norwayeast",
|
||||
"swedencentral",
|
||||
"switzerlandnorth",
|
||||
"switzerlandwest",
|
||||
"uksouth",
|
||||
"uaenorth",
|
||||
"brazilsouth",
|
||||
"qatarcentral",
|
||||
"centralus",
|
||||
"eastus",
|
||||
"eastus2",
|
||||
"northcentralus",
|
||||
"southcentralus",
|
||||
"westcentralus",
|
||||
"westus",
|
||||
"westus2",
|
||||
"westus3",
|
||||
],
|
||||
},
|
||||
"azure_tts_subscription_key": {
|
||||
"type": "string",
|
||||
"description": "服务订阅密钥",
|
||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
||||
},
|
||||
"dashscope_tts_voice": {
|
||||
"description": "语音合成模型",
|
||||
"type": "string",
|
||||
@@ -726,6 +1039,12 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_url_context": {
|
||||
"description": "启用URL上下文功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
"type": "object",
|
||||
@@ -777,6 +1096,109 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"minimax-group-id": {
|
||||
"type": "string",
|
||||
"description": "用户组",
|
||||
"hint": "于账户管理->基本信息中可见",
|
||||
},
|
||||
"minimax-langboost": {
|
||||
"type": "string",
|
||||
"description": "指定语言/方言",
|
||||
"hint": "增强对指定的小语种和方言的识别能力,设置后可以提升在指定小语种/方言场景下的语音表现",
|
||||
"options": [
|
||||
"Chinese",
|
||||
"Chinese,Yue",
|
||||
"English",
|
||||
"Arabic",
|
||||
"Russian",
|
||||
"Spanish",
|
||||
"French",
|
||||
"Portuguese",
|
||||
"German",
|
||||
"Turkish",
|
||||
"Dutch",
|
||||
"Ukrainian",
|
||||
"Vietnamese",
|
||||
"Indonesian",
|
||||
"Japanese",
|
||||
"Italian",
|
||||
"Korean",
|
||||
"Thai",
|
||||
"Polish",
|
||||
"Romanian",
|
||||
"Greek",
|
||||
"Czech",
|
||||
"Finnish",
|
||||
"Hindi",
|
||||
"auto",
|
||||
],
|
||||
},
|
||||
"minimax-voice-speed": {
|
||||
"type": "float",
|
||||
"description": "语速",
|
||||
"hint": "生成声音的语速, 取值[0.5, 2], 默认为1.0, 取值越大,语速越快",
|
||||
},
|
||||
"minimax-voice-vol": {
|
||||
"type": "float",
|
||||
"description": "音量",
|
||||
"hint": "生成声音的音量, 取值(0, 10], 默认为1.0, 取值越大,音量越高",
|
||||
},
|
||||
"minimax-voice-pitch": {
|
||||
"type": "int",
|
||||
"description": "语调",
|
||||
"hint": "生成声音的语调, 取值[-12, 12], 默认为0",
|
||||
},
|
||||
"minimax-is-timber-weight": {
|
||||
"type": "bool",
|
||||
"description": "启用混合音色",
|
||||
"hint": "启用混合音色, 支持以自定义权重混合最多四种音色, 启用后自动忽略单一音色设置",
|
||||
},
|
||||
"minimax-timber-weight": {
|
||||
"type": "string",
|
||||
"description": "混合音色",
|
||||
"editor_mode": True,
|
||||
"hint": "混合音色及其权重, 最多支持四种音色, 权重为整数, 取值[1, 100]. 可在官网API语音调试台预览代码获得预设以及编写模板, 需要严格按照json字符串格式编写, 可以查看控制台判断是否解析成功. 具体结构可参照默认值以及官网代码预览.",
|
||||
},
|
||||
"minimax-voice-id": {
|
||||
"type": "string",
|
||||
"description": "单一音色",
|
||||
"hint": "单一音色编号, 详见官网文档",
|
||||
},
|
||||
"minimax-voice-emotion": {
|
||||
"type": "string",
|
||||
"description": "情绪",
|
||||
"hint": "控制合成语音的情绪",
|
||||
"options": [
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"fearful",
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
],
|
||||
},
|
||||
"minimax-voice-latex": {
|
||||
"type": "bool",
|
||||
"description": "支持朗读latex公式",
|
||||
"hint": "朗读latex公式, 但是需要确保输入文本按官网要求格式化",
|
||||
},
|
||||
"minimax-voice-english-normalization": {
|
||||
"type": "bool",
|
||||
"description": "支持英语文本规范化",
|
||||
"hint": "可提升数字阅读场景的性能,但会略微增加延迟",
|
||||
},
|
||||
"rag_options": {
|
||||
"description": "RAG 选项",
|
||||
"type": "object",
|
||||
@@ -874,7 +1296,12 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商类型",
|
||||
"description": "模型提供商种类",
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
},
|
||||
"provider_type": {
|
||||
"description": "模型提供商能力种类",
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
},
|
||||
@@ -973,9 +1400,19 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"separate_provider": {
|
||||
"description": "提供商会话隔离",
|
||||
"type": "bool",
|
||||
"hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。",
|
||||
},
|
||||
"default_provider_id": {
|
||||
"description": "默认模型提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "可选。每个聊天会话的默认提供商 ID。",
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"type": "string",
|
||||
@@ -1088,7 +1525,7 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
"description": "提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
@@ -1105,7 +1542,7 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"description": "提供商 ID",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
@@ -1115,6 +1552,11 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"use_file_service": {
|
||||
"description": "使用文件服务提供 TTS 语音文件",
|
||||
"type": "bool",
|
||||
"hint": "启用后,如已配置 callback_api_base ,将会使用文件服务提供TTS语音文件",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -1227,6 +1669,12 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||
},
|
||||
"callback_api_base": {
|
||||
"description": "对外可达的回调接口地址",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
||||
},
|
||||
"log_level": {
|
||||
"description": "控制台日志级别",
|
||||
"type": "string",
|
||||
@@ -1244,6 +1692,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
|
||||
},
|
||||
"t2i_use_file_service": {
|
||||
"description": "本地文本转图像使用文件服务提供文件",
|
||||
"type": "bool",
|
||||
"hint": "当 t2i_strategy 为 local 并且配置 callback_api_base 时生效。是否使用文件服务提供文件。",
|
||||
},
|
||||
"pip_install_arg": {
|
||||
"description": "pip 安装参数",
|
||||
"type": "string",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
|
||||
工作流程:
|
||||
@@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
@@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map
|
||||
class AstrBotCoreLifecycle:
|
||||
"""
|
||||
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
EventBus 等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
"""
|
||||
@@ -54,7 +53,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
|
||||
# 初始化日志代理
|
||||
@@ -73,9 +72,6 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
@@ -87,7 +83,6 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
@@ -11,7 +11,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
|
||||
with open(
|
||||
os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8"
|
||||
) as f:
|
||||
sql = f.read()
|
||||
|
||||
# 初始化数据库
|
||||
|
||||
46
astrbot/core/db/vec_db/base.py
Normal file
46
astrbot/core/db/vec_db/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
similarity: float
|
||||
data: dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化向量数据库
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
top_k (int): 返回的最相似文档的数量
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, doc_id: str) -> bool:
|
||||
"""
|
||||
删除指定文档。
|
||||
Args:
|
||||
doc_id (str): 要删除的文档 ID
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
)
|
||||
await self.connection.commit()
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
Args:
|
||||
row (tuple): The row to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
"doc_id": row[1],
|
||||
"text": row[2],
|
||||
"metadata": row[3],
|
||||
"created_at": row[4],
|
||||
"updated_at": row[5],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
try:
|
||||
import faiss
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
||||
)
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
if path and os.path.exists(path):
|
||||
self.index = faiss.read_index(path)
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 要插入的向量
|
||||
id (int): 向量的ID
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
"""搜索最相似的向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 查询向量
|
||||
k (int): 返回的最相似向量的数量
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
Args:
|
||||
path (str): 保存索引的路径
|
||||
"""
|
||||
faiss.write_index(self.index, self.path)
|
||||
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
doc_id TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
|
||||
|
||||
CREATE INDEX idx_documents_user_id ON documents(user_id);
|
||||
CREATE INDEX idx_documents_group_id ON documents(group_id);
|
||||
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
import json
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
A class to represent a vector database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
self.embedding_provider = embedding_provider
|
||||
self.document_storage = DocumentStorage(doc_store_path)
|
||||
self.embedding_storage = EmbeddingStorage(
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
embedding = await self.embedding_provider.get_embedding(query)
|
||||
scores, indices = await self.embedding_storage.search(
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
# TODO: rerank
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||
# NOTE: maybe the size is less than k.
|
||||
fetched_docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters or {}, ids=indices[0]
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
pos = idx_pos.get(indice_idx)
|
||||
if pos is None:
|
||||
continue
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
return result_docs[:k]
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
删除一条文档
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
68
astrbot/core/file_token_service.py
Normal file
68
astrbot/core/file_token_service.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
||||
"""向令牌服务注册一个文件。
|
||||
|
||||
Args:
|
||||
file_path(str): 文件路径
|
||||
timeout(float): 超时时间,单位秒(可选)
|
||||
|
||||
Returns:
|
||||
str: 一个单次令牌
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 当路径不存在时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
||||
self.staged_files[file_token] = (file_path, expire_time)
|
||||
return file_token
|
||||
|
||||
async def handle_file(self, file_token: str) -> str:
|
||||
"""根据令牌获取文件路径,使用后令牌失效。
|
||||
|
||||
Args:
|
||||
file_token(str): 注册时返回的令牌
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
|
||||
Raises:
|
||||
KeyError: 当令牌不存在或已过期时抛出
|
||||
FileNotFoundError: 当文件本身已被删除时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if file_token not in self.staged_files:
|
||||
raise KeyError(f"无效或过期的文件 token: {file_token}")
|
||||
|
||||
file_path, _ = self.staged_files.pop(file_token)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
return file_path
|
||||
@@ -26,13 +26,14 @@ class InitialLoader:
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||
return
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
|
||||
@@ -22,14 +22,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import typing as T
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from pydantic.v1 import BaseModel
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||
|
||||
from astrbot.core import astrbot_config, file_token_service, logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
@@ -97,6 +102,10 @@ class BaseMessageComponent(BaseModel):
|
||||
data[k] = v
|
||||
return {"type": self.type.lower(), "data": data}
|
||||
|
||||
async def to_dict(self) -> dict:
|
||||
# 默认情况下,回退到旧的同步 toDict()
|
||||
return self.toDict()
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
@@ -113,6 +122,9 @@ class Plain(BaseMessageComponent):
|
||||
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||
)
|
||||
|
||||
def toDict(self):
|
||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
@@ -165,7 +177,8 @@ class Record(BaseMessageComponent):
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
@@ -196,6 +209,29 @@ class Record(BaseMessageComponent):
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将语音注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Video(BaseMessageComponent):
|
||||
type: ComponentType = "Video"
|
||||
@@ -206,9 +242,6 @@ class Video(BaseMessageComponent):
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "c" and _[k] not in [2, 3]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -221,6 +254,70 @@ class Video(BaseMessageComponent):
|
||||
return Video(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
async def convert_to_file_path(self) -> str:
|
||||
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。
|
||||
|
||||
Returns:
|
||||
str: 视频的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.file
|
||||
if url and url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url and url.startswith("http"):
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(url, video_file_path)
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
else:
|
||||
raise Exception(f"download failed: {url}")
|
||||
elif os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将视频注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = self.file
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated video file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
@@ -230,6 +327,12 @@ class At(BaseMessageComponent):
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
def toDict(self):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {"qq": str(self.qq)},
|
||||
}
|
||||
|
||||
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
@@ -369,7 +472,8 @@ class Image(BaseMessageComponent):
|
||||
elif url and url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
@@ -401,23 +505,44 @@ class Image(BaseMessageComponent):
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将图片注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
id: T.Union[str, int]
|
||||
"""所引用的消息 ID"""
|
||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||
"""引用的消息段列表"""
|
||||
"""被引用的消息段列表"""
|
||||
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||
"""引用的消息发送者 ID"""
|
||||
"""被引用的消息对应的发送者的 ID"""
|
||||
sender_nickname: T.Optional[str] = ""
|
||||
"""引用的消息发送者昵称"""
|
||||
"""被引用的消息对应的发送者的昵称"""
|
||||
time: T.Optional[int] = 0
|
||||
"""引用的消息发送时间"""
|
||||
"""被引用的消息发送时间"""
|
||||
message_str: T.Optional[str] = ""
|
||||
"""解析后的纯文本消息字符串"""
|
||||
sender_str: T.Optional[str] = ""
|
||||
"""被引用的消息纯文本"""
|
||||
"""被引用的消息解析后的纯文本消息字符串"""
|
||||
|
||||
text: T.Optional[str] = ""
|
||||
"""deprecated"""
|
||||
@@ -462,28 +587,48 @@ class Node(BaseMessageComponent):
|
||||
type: ComponentType = "Node"
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[int] = 0 # qq号
|
||||
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||
uin: T.Optional[str] = "0" # qq号
|
||||
content: T.Optional[list[BaseMessageComponent]] = []
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0 # 忽略
|
||||
|
||||
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||
if isinstance(content, list):
|
||||
_content = None
|
||||
if all(isinstance(item, Node) for item in content):
|
||||
_content = [node.toDict() for node in content]
|
||||
else:
|
||||
_content = ""
|
||||
for chain in content:
|
||||
_content += chain.toString()
|
||||
content = _content
|
||||
elif isinstance(content, Node):
|
||||
content = content.toDict()
|
||||
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||
if isinstance(content, Node):
|
||||
# back
|
||||
content = [content]
|
||||
super().__init__(content=content, **_)
|
||||
|
||||
def toString(self):
|
||||
# logger.warn("Protocol: node doesn't support stringify")
|
||||
return ""
|
||||
async def to_dict(self):
|
||||
data_content = []
|
||||
for comp in self.content:
|
||||
if isinstance(comp, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await comp.convert_to_base64()
|
||||
data_content.append(
|
||||
{
|
||||
"type": comp.type.lower(),
|
||||
"data": {"file": f"base64://{bs64}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, (Node, Nodes)):
|
||||
# For Node segments, we recursively convert them to dict
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
else:
|
||||
d = comp.toDict()
|
||||
data_content.append(d)
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": str(self.uin),
|
||||
"nickname": self.name,
|
||||
"content": data_content,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Nodes(BaseMessageComponent):
|
||||
@@ -494,7 +639,22 @@ class Nodes(BaseMessageComponent):
|
||||
super().__init__(nodes=nodes, **_)
|
||||
|
||||
def toDict(self):
|
||||
return {"messages": [node.toDict() for node in self.nodes]}
|
||||
"""Deprecated. Use to_dict instead"""
|
||||
ret = {
|
||||
"messages": [],
|
||||
}
|
||||
for node in self.nodes:
|
||||
d = node.toDict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
d = await node.to_dict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
|
||||
class Xml(BaseMessageComponent):
|
||||
@@ -554,15 +714,136 @@ class Unknown(BaseMessageComponent):
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
"""
|
||||
目前此消息段只适配了 Napcat。
|
||||
文件消息段
|
||||
"""
|
||||
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
file_: T.Optional[str] = "" # 本地路径
|
||||
url: T.Optional[str] = "" # url
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
def __init__(self, name: str, file: str = "", url: str = ""):
|
||||
"""文件消息段。"""
|
||||
super().__init__(name=name, file_=file, url=url)
|
||||
|
||||
@property
|
||||
def file(self) -> str:
|
||||
"""
|
||||
获取文件路径,如果文件不存在但有URL,则同步下载文件
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
"""
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段"
|
||||
)
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
@file.setter
|
||||
def file(self, value: str):
|
||||
"""
|
||||
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
|
||||
|
||||
Args:
|
||||
value (str): 文件路径或URL
|
||||
"""
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
self.url = value
|
||||
else:
|
||||
self.file_ = value
|
||||
|
||||
async def get_file(self, allow_return_url: bool = False) -> str:
|
||||
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
|
||||
|
||||
Args:
|
||||
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
|
||||
注意,如果为 True,也可能返回文件路径。
|
||||
Returns:
|
||||
str: 文件路径或者 http 下载链接
|
||||
"""
|
||||
if allow_return_url and self.url:
|
||||
return self.url
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将文件注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.get_file()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = await self.get_file(allow_return_url=True)
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {
|
||||
"name": self.name,
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
|
||||
@@ -43,31 +43,31 @@ class PreProcessStage(Stage):
|
||||
# STT
|
||||
if self.stt_settings.get("enable", False):
|
||||
# TODO: 独立
|
||||
stt_provider = (
|
||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
)
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
|
||||
@@ -26,6 +26,14 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -60,15 +68,19 @@ class LLMRequestSubStage(Stage):
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
umo = event.unified_msg_origin
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(
|
||||
req, ProviderRequest
|
||||
), "provider_request 必须是 ProviderRequest 类型。"
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
@@ -149,7 +161,14 @@ class LLMRequestSubStage(Stage):
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next((i for i, item in enumerate(req.contexts) if item.get("role") == "user"), None)
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(req.contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
req.contexts = req.contexts[index:]
|
||||
|
||||
@@ -265,6 +284,71 @@ class LLMRequestSubStage(Stage):
|
||||
event.set_extra("tool_call_result", None)
|
||||
yield
|
||||
|
||||
# 暂时直接发出去
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
|
||||
if event.get_platform_name() == "webchat":
|
||||
# 异步处理 WebChat 特殊情况
|
||||
asyncio.create_task(self._handle_webchat(event, req))
|
||||
|
||||
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
if conversation and not req.conversation.title:
|
||||
messages = json.loads(conversation.history)
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
# if len(latest_pair) > 1:
|
||||
# cleaned_text += (
|
||||
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
|
||||
# )
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await provider.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
prompt=(
|
||||
f"Please summarize the following query of user:\n"
|
||||
f"{cleaned_text}\n"
|
||||
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
||||
"You must use the same language as the user."
|
||||
"If you think the dialog is too short to summarize, only output a special mark: `None`"
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
logger.debug(
|
||||
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
|
||||
)
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "None" == title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
event.unified_msg_origin, title=title
|
||||
)
|
||||
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
|
||||
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
|
||||
# TODO: 优化 WebChat 适配器的对话管理
|
||||
if event.session_id:
|
||||
username, cid = event.session_id.split("!")[1:3]
|
||||
db_helper = self.ctx.plugin_manager.context._db
|
||||
db_helper.update_conversation_title(
|
||||
user_id=username,
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
web_chat_back_queue.put_nowait(
|
||||
{
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_llm_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -375,21 +459,68 @@ class LLMRequestSubStage(Stage):
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if res:
|
||||
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
else:
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取处理器,过滤掉平台不兼容的处理器
|
||||
platform_id = event.get_platform_id()
|
||||
star_md = star_map.get(func_tool.handler_module_path)
|
||||
if (
|
||||
star_md and
|
||||
platform_id in star_md.supported_platforms
|
||||
star_md
|
||||
and platform_id in star_md.supported_platforms
|
||||
and not star_md.supported_platforms[platform_id]
|
||||
):
|
||||
logger.debug(
|
||||
|
||||
@@ -58,33 +58,30 @@ class RateLimitStage(Stage):
|
||||
now = datetime.now()
|
||||
|
||||
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
# 检查并处理限流,可能需要多次检查直到满足条件
|
||||
while True:
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
if len(timestamps) < self.rate_limit_count:
|
||||
timestamps.append(now)
|
||||
break
|
||||
else:
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds() + 0.3
|
||||
|
||||
if len(timestamps) >= self.rate_limit_count:
|
||||
# 达到限流阈值,计算下一个窗口的时间
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(
|
||||
timestamps, now + timedelta(seconds=stall_duration)
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
now = datetime.now()
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
def _remove_expired_timestamps(
|
||||
self, timestamps: Deque[datetime], now: datetime
|
||||
|
||||
@@ -26,33 +26,13 @@ class RespondStage(Stage):
|
||||
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||
Comp.AtAll: lambda comp: True, # @所有人
|
||||
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||
Comp.Music: lambda comp: bool(comp._type)
|
||||
and bool(comp.url)
|
||||
and bool(comp.audio), # 音乐
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||
Comp.Node: lambda comp: bool(comp.name)
|
||||
and comp.uin != 0
|
||||
and bool(comp.content), # 一个转发节点
|
||||
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
@@ -129,8 +109,6 @@ class RespondStage(Stage):
|
||||
if comp_type in self._component_validators:
|
||||
if self._component_validators[comp_type](comp):
|
||||
return False
|
||||
else:
|
||||
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
@@ -175,6 +153,11 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||
non_record_comps = [
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
]
|
||||
|
||||
if self.enable_seg and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
@@ -192,21 +175,39 @@ class RespondStage(Stage):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
try:
|
||||
await event.send(result)
|
||||
await event.send(MessageChain(non_record_comps))
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -168,30 +169,55 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
):
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(
|
||||
Record(file=audio_path, url=audio_path)
|
||||
)
|
||||
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
@@ -225,6 +251,14 @@ class ResultDecorateStage(Stage):
|
||||
if url:
|
||||
if url.startswith("http"):
|
||||
result.chain = [Image.fromURL(url)]
|
||||
elif (
|
||||
self.ctx.astrbot_config["t2i_use_file_service"]
|
||||
and self.ctx.astrbot_config["callback_api_base"]
|
||||
):
|
||||
token = await file_token_service.register_file(url)
|
||||
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
result.chain = [Image.fromURL(url)]
|
||||
else:
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.message.components import At, AtAll
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
@@ -39,6 +39,9 @@ class WakingCheckStage(Stage):
|
||||
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_bot_self_message", False
|
||||
)
|
||||
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_at_all", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -79,10 +82,9 @@ class WakingCheckStage(Stage):
|
||||
if not is_wake:
|
||||
# 检查是否有 at 消息
|
||||
for message in messages:
|
||||
if isinstance(message, At) and (
|
||||
if (isinstance(message, At) and (
|
||||
str(message.qq) == str(event.get_self_id())
|
||||
or str(message.qq) == "all"
|
||||
):
|
||||
)) or (isinstance(message, AtAll) and not self.ignore_at_all):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
wake_prefix = ""
|
||||
@@ -137,7 +139,7 @@ class WakingCheckStage(Stage):
|
||||
if self.no_permission_reply:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
||||
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
|
||||
@@ -62,6 +62,10 @@ class PlatformManager:
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
)
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
case "dingtalk":
|
||||
@@ -72,6 +76,8 @@ class PlatformManager:
|
||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||
case "wecom":
|
||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||
|
||||
@@ -3,7 +3,16 @@ import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Node,
|
||||
Nodes,
|
||||
Plain,
|
||||
Record,
|
||||
Video,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
|
||||
|
||||
@@ -14,44 +23,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
|
||||
"""修复部分字段"""
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
return {
|
||||
"type": segment.type.lower(),
|
||||
"data": {
|
||||
"file": f"base64://{bs64}",
|
||||
},
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
elif isinstance(segment, Video):
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
else:
|
||||
# For other segments, we simply convert them to a dict by calling toDict
|
||||
return segment.toDict()
|
||||
|
||||
@staticmethod
|
||||
async def _parse_onebot_json(message_chain: MessageChain):
|
||||
"""解析成 OneBot json 格式"""
|
||||
ret = []
|
||||
for segment in message_chain.chain:
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d["type"] = "text"
|
||||
d["data"]["text"] = segment.text.strip()
|
||||
# 如果是空文本或者只带换行符的文本,不发送
|
||||
if not d["data"]["text"]:
|
||||
if not segment.text.strip():
|
||||
continue
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
d["data"] = {
|
||||
"file": f"base64://{bs64}",
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d["data"] = {
|
||||
"qq": str(segment.qq) # 转换为字符串
|
||||
}
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
|
||||
if not ret:
|
||||
return
|
||||
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 转发消息不能和普通消息混在一起发送
|
||||
send_one_by_one = True
|
||||
break
|
||||
|
||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||
send_one_by_one = any(
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
||||
)
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
@@ -61,7 +72,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = seg.toDict()
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
@@ -70,6 +82,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
await self.bot.call_action(
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
elif isinstance(seg, File):
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
[d],
|
||||
)
|
||||
else:
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
@@ -79,6 +97,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if not ret:
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import itertools
|
||||
from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import (
|
||||
@@ -20,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -45,7 +44,12 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
||||
use_ws_reverse=True,
|
||||
import_name="aiocqhttp",
|
||||
api_timeout_sec=180,
|
||||
access_token=platform_config.get(
|
||||
"ws_reverse_token"
|
||||
), # 以防旧版本配置不存在
|
||||
)
|
||||
|
||||
@self.bot.on_request()
|
||||
@@ -99,6 +103,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if event["post_type"] == "message":
|
||||
abm = await self._convert_handle_message_event(event)
|
||||
if abm.sender.user_id == "2854196310":
|
||||
# 屏蔽 QQ 管家的消息
|
||||
return
|
||||
elif event["post_type"] == "notice":
|
||||
abm = await self._convert_handle_notice_event(event)
|
||||
elif event["post_type"] == "request":
|
||||
@@ -119,6 +126,12 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -155,7 +168,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if "sub_type" in event:
|
||||
if event["sub_type"] == "poke" and "target_id" in event:
|
||||
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
||||
abm.message.append(
|
||||
Poke(qq=str(event["target_id"]), type="poke")
|
||||
) # noqa: F405
|
||||
|
||||
return abm
|
||||
|
||||
@@ -202,82 +217,122 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for m in event.message:
|
||||
t = m["type"]
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
a = None
|
||||
if t == "text":
|
||||
message_str += m["data"]["text"].strip()
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
current_text = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
if not current_text:
|
||||
# 如果文本段为空,则跳过
|
||||
continue
|
||||
message_str += current_text
|
||||
a = ComponentTypes[t](text=current_text) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
elif t == "file":
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
for m in m_group:
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
# Napcat
|
||||
ret = None
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_group_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
group_id=event.group_id,
|
||||
)
|
||||
elif abm.type == MessageType.FRIEND_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_private_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
path = os.path.join("data/temp", file_name)
|
||||
await download_file(m["data"]["url"], path)
|
||||
|
||||
m["data"] = {"file": path, "name": file_name}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
else:
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(
|
||||
action="get_file",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if not ret.get("file", None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret["file"]):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
|
||||
)
|
||||
|
||||
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
elif t == "reply":
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
for m in m_group:
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
elif t == "at":
|
||||
first_at_self_processed = False
|
||||
|
||||
for m in m_group:
|
||||
try:
|
||||
if m["data"]["qq"] == "all":
|
||||
abm.message.append(At(qq="all", name="全体成员"))
|
||||
continue
|
||||
|
||||
at_info = await self.bot.call_action(
|
||||
action="get_stranger_info",
|
||||
user_id=int(m["data"]["qq"]),
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("nick", "")
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
At(
|
||||
qq=m["data"]["qq"],
|
||||
name=nickname,
|
||||
)
|
||||
)
|
||||
|
||||
if is_at_self and not first_at_self_processed:
|
||||
# 第一个@是机器人,不添加到message_str
|
||||
first_at_self_processed = True
|
||||
else:
|
||||
# 非第一个@机器人或@其他用户,添加到message_str
|
||||
message_str += f" @{nickname} "
|
||||
else:
|
||||
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
else:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -19,6 +20,7 @@ from ...register import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from dingtalk_stream import AckMessage
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
@@ -152,7 +154,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
"downloadCode": download_code,
|
||||
"robotCode": robot_code,
|
||||
}
|
||||
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
@@ -14,6 +15,7 @@ from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
@@ -63,7 +65,7 @@ class SimpleGewechatClient:
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_id>",
|
||||
"/astrbot-gewechat/file/<file_token>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
@@ -81,6 +83,11 @@ class SimpleGewechatClient:
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self.staged_files = {}
|
||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -143,18 +150,25 @@ class SimpleGewechatClient:
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
at_wxids = []
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
at_wxids = re.findall(
|
||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||
msg_source,
|
||||
)
|
||||
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
@@ -167,13 +181,12 @@ class SimpleGewechatClient:
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
if user_id == abm.self_id:
|
||||
logger.info("忽略自己发送的消息")
|
||||
return None
|
||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||
# if user_id == abm.self_id:
|
||||
# logger.info("忽略自己发送的消息")
|
||||
# return None
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
@@ -197,7 +210,19 @@ class SimpleGewechatClient:
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0]
|
||||
try:
|
||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||
user_real_name = info["nickName"]
|
||||
except Exception as e:
|
||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||
user_real_name = user_id
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||
for wxid in at_wxids:
|
||||
# 群聊里 At 其他人的列表
|
||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||
abm.message.append(At(qq=wxid, name=_username))
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
@@ -226,7 +251,10 @@ class SimpleGewechatClient:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
@@ -248,9 +276,12 @@ class SimpleGewechatClient:
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
abm_data = data_parser.parse_mutil_49()
|
||||
if abm_data:
|
||||
abm.message.append(abm_data)
|
||||
segments = data_parser.parse_mutil_49()
|
||||
if segments:
|
||||
abm.message.extend(segments)
|
||||
for seg in segments:
|
||||
if isinstance(seg, Plain):
|
||||
abm.message_str += seg.text
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
@@ -289,9 +320,33 @@ class SimpleGewechatClient:
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
async def _register_file(self, file_path: str) -> str:
|
||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||
|
||||
Args:
|
||||
file_path (str): 文件路径。
|
||||
Returns:
|
||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||
"""
|
||||
async with self.lock:
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
self.staged_files[file_token] = file_path
|
||||
return file_token
|
||||
|
||||
async def _handle_file(self, file_token):
|
||||
async with self.lock:
|
||||
if file_token not in self.staged_files:
|
||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||
return quart.abort(404)
|
||||
if not os.path.exists(self.staged_files[file_token]):
|
||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||
return quart.abort(404)
|
||||
file_path = self.staged_files[file_token]
|
||||
self.staged_files.pop(file_token, None)
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
@@ -407,8 +462,10 @@ class SimpleGewechatClient:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
if os.path.exists("data/temp/gewe_code"):
|
||||
with open("data/temp/gewe_code", "r") as f:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
@@ -419,9 +476,9 @@ class SimpleGewechatClient:
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
@@ -441,17 +498,18 @@ class SimpleGewechatClient:
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
if "status" in json_blob["data"]:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
|
||||
@@ -6,7 +6,7 @@ import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -21,6 +21,7 @@ from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
@@ -83,15 +84,9 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中
|
||||
temp_directory = os.path.abspath("data/temp")
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{client.file_server_url}/{file_id}"
|
||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||
token = await client._register_file(img_path)
|
||||
img_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
@@ -110,20 +105,33 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
await download_file(video_url, video_path)
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
|
||||
video_token = await client._register_file(video_path)
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
|
||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||
ff.options(command)
|
||||
thumb_file_id = os.path.basename(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
|
||||
thumb_token = await client._register_file(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
@@ -138,15 +146,12 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
file_id = os.path.basename(video_path)
|
||||
video_url = f"{client.file_server_url}/{file_id}"
|
||||
# 发送视频
|
||||
await client.post_video(
|
||||
to_wxid, video_url, thumb_url, video_duration
|
||||
to_wxid, video_callback_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时视频和缩略图文件
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
# 删除临时缩略图文件
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
@@ -154,7 +159,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
@@ -163,8 +169,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{client.file_server_url}/{file_id}"
|
||||
token = await client._register_file(silk_path)
|
||||
record_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
@@ -173,14 +179,17 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
file_id = os.path.basename(file_path)
|
||||
file_url = f"{client.file_server_url}/{file_id}"
|
||||
token = await client._register_file(file_path)
|
||||
file_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_id)
|
||||
await client.post_file(to_wxid, file_url, file_name)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Reply,
|
||||
Plain,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
@@ -11,7 +16,7 @@ class GeweDataParser:
|
||||
def _format_to_xml(self):
|
||||
return eT.fromstring(self.data)
|
||||
|
||||
def parse_mutil_49(self):
|
||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||
if appmsg_type is None:
|
||||
return
|
||||
@@ -34,13 +39,18 @@ class GeweDataParser:
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||
|
||||
def parse_reply(self) -> Reply | None:
|
||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||
"""解析引用消息
|
||||
|
||||
Returns:
|
||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||
"""
|
||||
try:
|
||||
replied_id = -1
|
||||
replied_uid = 0
|
||||
replied_nickname = ""
|
||||
replied_content = ""
|
||||
content = ""
|
||||
replied_content = "" # 被引用者说的内容
|
||||
content = "" # 引用者说的内容
|
||||
|
||||
root = self._format_to_xml()
|
||||
refermsg = root.find(".//refermsg")
|
||||
@@ -57,22 +67,44 @@ class GeweDataParser:
|
||||
if displayname is not None:
|
||||
replied_nickname = displayname.text
|
||||
if refermsg_content is not None:
|
||||
replied_content = refermsg_content.text
|
||||
# 处理引用嵌套,包括嵌套公众号消息
|
||||
if refermsg_content.text.startswith(
|
||||
"<msg>"
|
||||
) or refermsg_content.text.startswith("<?xml"):
|
||||
try:
|
||||
logger.debug("gewechat: Reference message is nested")
|
||||
refer_root = eT.fromstring(refermsg_content.text)
|
||||
img = refer_root.find("img")
|
||||
if img is not None:
|
||||
replied_content = "[图片]"
|
||||
else:
|
||||
app_msg = refer_root.find("appmsg")
|
||||
refermsg_content_title = app_msg.find("title")
|
||||
logger.debug(
|
||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||
)
|
||||
replied_content = refermsg_content_title.text
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: nested failed, {e}")
|
||||
# 处理异常情况
|
||||
replied_content = refermsg_content.text
|
||||
else:
|
||||
replied_content = refermsg_content.text
|
||||
|
||||
# 提取引用者说的内容
|
||||
title = root.find(".//appmsg/title")
|
||||
if title is not None:
|
||||
content = title.text
|
||||
|
||||
r = Reply(
|
||||
reply_seg = Reply(
|
||||
id=replied_id,
|
||||
chain=[Plain(content)],
|
||||
chain=[Plain(replied_content)],
|
||||
sender_id=replied_uid,
|
||||
sender_nickname=replied_nickname,
|
||||
sender_str=replied_content,
|
||||
message_str=content,
|
||||
message_str=replied_content,
|
||||
)
|
||||
return r
|
||||
plain_seg = Plain(content)
|
||||
return [reply_seg, plain_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
import lark_oapi as lark
|
||||
@@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -40,7 +42,8 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
# save as temp file
|
||||
file_path = f"data/temp/{uuid.uuid4()}_test.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
|
||||
@@ -58,6 +58,14 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
self.base_url = base_url
|
||||
|
||||
self.enable_command_register = self.config.get(
|
||||
"telegram_command_register", True
|
||||
)
|
||||
self.enable_command_refresh = self.config.get(
|
||||
"telegram_command_auto_refresh", True
|
||||
)
|
||||
self.last_command_hash = None
|
||||
|
||||
self.application = (
|
||||
ApplicationBuilder()
|
||||
.token(self.config["telegram_token"])
|
||||
@@ -95,17 +103,19 @@ class TelegramPlatformAdapter(Platform):
|
||||
async def run(self):
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
await self.register_commands()
|
||||
|
||||
# TODO 使用更优雅的方式重新注册命令
|
||||
self.scheduler.add_job(
|
||||
self.register_commands,
|
||||
"interval",
|
||||
minutes=5,
|
||||
id="telegram_command_register",
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
self.scheduler.start()
|
||||
if self.enable_command_register:
|
||||
await self.register_commands()
|
||||
|
||||
if self.enable_command_refresh and self.enable_command_register:
|
||||
self.scheduler.add_job(
|
||||
self.register_commands,
|
||||
"interval",
|
||||
seconds=self.config.get("telegram_command_register_interval", 300),
|
||||
id="telegram_command_register",
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
@@ -114,10 +124,16 @@ class TelegramPlatformAdapter(Platform):
|
||||
async def register_commands(self):
|
||||
"""收集所有注册的指令并注册到 Telegram"""
|
||||
try:
|
||||
await self.client.delete_my_commands()
|
||||
commands = self.collect_commands()
|
||||
|
||||
if commands:
|
||||
current_hash = hash(
|
||||
tuple((cmd.command, cmd.description) for cmd in commands)
|
||||
)
|
||||
if current_hash == self.last_command_hash:
|
||||
return
|
||||
self.last_command_hash = current_hash
|
||||
await self.client.delete_my_commands()
|
||||
await self.client.set_my_commands(commands)
|
||||
|
||||
except Exception as e:
|
||||
@@ -128,8 +144,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
command_dict = {}
|
||||
skip_commands = {"start"}
|
||||
|
||||
for handler_md in star_handlers_registry._handlers:
|
||||
handler_metadata = handler_md[1]
|
||||
for handler_md in star_handlers_registry:
|
||||
handler_metadata = handler_md
|
||||
if not star_map[handler_metadata.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_metadata.event_filters:
|
||||
@@ -266,10 +282,12 @@ class TelegramPlatformAdapter(Platform):
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
plain_text = (
|
||||
plain_text[: entity.offset]
|
||||
+ plain_text[entity.offset + entity.length :]
|
||||
)
|
||||
# 如果mention是当前bot则移除;否则保留
|
||||
if name.lower() == context.bot.username.lower():
|
||||
plain_text = (
|
||||
plain_text[: entity.offset]
|
||||
+ plain_text[entity.offset + entity.length :]
|
||||
)
|
||||
|
||||
if plain_text:
|
||||
message.message.append(Comp.Plain(plain_text))
|
||||
@@ -342,7 +360,9 @@ class TelegramPlatformAdapter(Platform):
|
||||
self.scheduler.shutdown()
|
||||
|
||||
await self.application.stop()
|
||||
await self.client.delete_my_commands()
|
||||
|
||||
if self.enable_command_register:
|
||||
await self.client.delete_my_commands()
|
||||
|
||||
# 保险起见先判断是否存在updater对象
|
||||
if self.application.updater is not None:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -13,9 +15,20 @@ from astrbot.api.message_components import (
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
# Telegram 的最大消息长度限制
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
SPLIT_PATTERNS = {
|
||||
"paragraph": re.compile(r"\n\n"),
|
||||
"line": re.compile(r"\n"),
|
||||
"sentence": re.compile(r"[.!?。!?]"),
|
||||
"word": re.compile(r"\s"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
@@ -27,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
|
||||
def _split_message(self, text: str) -> list[str]:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
chunks.append(text)
|
||||
break
|
||||
|
||||
split_point = self.MAX_MESSAGE_LENGTH
|
||||
segment = text[: self.MAX_MESSAGE_LENGTH]
|
||||
|
||||
for _, pattern in self.SPLIT_PATTERNS.items():
|
||||
if matches := list(pattern.finditer(segment)):
|
||||
last_match = matches[-1]
|
||||
split_point = last_match.end()
|
||||
break
|
||||
|
||||
chunks.append(text[:split_point])
|
||||
text = text[split_point:].lstrip()
|
||||
|
||||
return chunks
|
||||
|
||||
async def send_with_client(
|
||||
self, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
@@ -57,25 +95,29 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
if isinstance(i, Plain):
|
||||
if at_user_id and not at_flag:
|
||||
i.text = f"@{at_user_id} " + i.text
|
||||
i.text = f"@{at_user_id} {i.text}"
|
||||
at_flag = True
|
||||
text = i.text
|
||||
try:
|
||||
text = telegramify_markdown.markdownify(
|
||||
i.text, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||
)
|
||||
return
|
||||
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||
chunks = self._split_message(i.text)
|
||||
for chunk in chunks:
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text, parse_mode="MarkdownV2", **payload
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead."
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -126,7 +168,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -143,17 +186,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
|
||||
# Plain
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
else:
|
||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
@@ -172,6 +205,18 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
delta = ""
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
|
||||
try:
|
||||
if delta and current_content != delta:
|
||||
|
||||
@@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class QueueListener:
|
||||
@@ -40,7 +41,8 @@ class WebChatAdapter(Platform):
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id")
|
||||
|
||||
@@ -6,8 +6,9 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
imgs_dir = "data/webchat/imgs"
|
||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
@@ -0,0 +1,887 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import websockets
|
||||
from astrbot import logger
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api.platform import Platform, PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .wechatpadpro_message_event import WeChatPadProMessageEvent
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||
class WeChatPadProAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wechatpadpro",
|
||||
description="WeChatPadPro 消息平台适配器",
|
||||
id=self.config.get("id", "wechatpadpro"),
|
||||
)
|
||||
|
||||
# 保存配置信息
|
||||
self.admin_key = self.config.get("admin_key")
|
||||
self.host = self.config.get("host")
|
||||
self.port = self.config.get("port")
|
||||
self.active_mesasge_poll: bool = self.config.get(
|
||||
"wpp_active_message_poll", False
|
||||
)
|
||||
self.active_message_poll_interval: int = self.config.get(
|
||||
"wpp_active_message_poll_interval", 5
|
||||
)
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(), "wechatpadpro_credentials.json"
|
||||
) # 持久化文件路径
|
||||
self.ws_handle_task = None
|
||||
|
||||
# 添加图片消息缓存,用于引用消息处理
|
||||
self.cached_images = {}
|
||||
"""缓存图片消息。key是NewMsgId (对应引用消息的svrid),value是图片的base64数据"""
|
||||
# 设置缓存大小限制,避免内存占用过大
|
||||
self.max_image_cache = 50
|
||||
|
||||
# 添加文本消息缓存,用于引用消息处理
|
||||
self.cached_texts = {}
|
||||
"""缓存文本消息。key是NewMsgId (对应引用消息的svrid),value是消息文本内容"""
|
||||
# 设置文本缓存大小限制
|
||||
self.max_text_cache = 100
|
||||
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
启动平台适配器的运行实例。
|
||||
"""
|
||||
logger.info("WeChatPadPro 适配器正在启动...")
|
||||
|
||||
if loaded_credentials := self.load_credentials():
|
||||
self.auth_key = loaded_credentials.get("auth_key")
|
||||
self.wxid = loaded_credentials.get("wxid")
|
||||
|
||||
isLoginIn = await self.check_online_status()
|
||||
|
||||
# 检查在线状态
|
||||
if self.auth_key and isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
|
||||
# 如果在线,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
else:
|
||||
# 1. 生成授权码
|
||||
if not self.auth_key:
|
||||
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
|
||||
await self.generate_auth_key()
|
||||
|
||||
# 2. 获取登录二维码
|
||||
if not isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
|
||||
qr_code_url = await self.get_login_qr_code()
|
||||
|
||||
if qr_code_url:
|
||||
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
|
||||
else:
|
||||
logger.error("无法获取登录二维码。")
|
||||
return
|
||||
|
||||
# 3. 检测扫码状态
|
||||
login_successful = await self.check_login_status()
|
||||
|
||||
if login_successful:
|
||||
logger.info("登录成功,WeChatPadPro适配器已连接。")
|
||||
else:
|
||||
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
|
||||
await self.terminate()
|
||||
return
|
||||
|
||||
# 登录成功后,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
|
||||
self._shutdown_event = asyncio.Event()
|
||||
await self._shutdown_event.wait()
|
||||
logger.info("WeChatPadPro 适配器已停止。")
|
||||
|
||||
def load_credentials(self):
|
||||
"""
|
||||
从文件中加载 auth_key 和 wxid。
|
||||
"""
|
||||
if os.path.exists(self.credentials_file):
|
||||
try:
|
||||
with open(self.credentials_file, "r") as f:
|
||||
credentials = json.load(f)
|
||||
logger.info("成功加载 WeChatPadPro 凭据。")
|
||||
return credentials
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WeChatPadPro 凭据失败: {e}")
|
||||
return None
|
||||
|
||||
def save_credentials(self):
|
||||
"""
|
||||
将 auth_key 和 wxid 保存到文件。
|
||||
"""
|
||||
credentials = {
|
||||
"auth_key": self.auth_key,
|
||||
"wxid": self.wxid,
|
||||
}
|
||||
try:
|
||||
# 确保数据目录存在
|
||||
data_dir = os.path.dirname(self.credentials_file)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
with open(self.credentials_file, "w") as f:
|
||||
json.dump(credentials, f)
|
||||
logger.info("成功保存 WeChatPadPro 凭据。")
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
|
||||
|
||||
async def check_online_status(self):
|
||||
"""
|
||||
检查 WeChatPadPro 设备是否在线。
|
||||
"""
|
||||
url = f"{self.base_url}/login/GetLoginStatus"
|
||||
params = {"key": self.auth_key}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response_data = await response.json()
|
||||
# 根据提供的在线接口返回示例,成功状态码是 200,loginState 为 1 表示在线
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
login_state = response_data.get("Data", {}).get("loginState")
|
||||
if login_state == 1:
|
||||
logger.info("WeChatPadPro 设备当前在线。")
|
||||
return True
|
||||
# login_state == 3 为离线状态
|
||||
elif login_state == 3:
|
||||
logger.info("WeChatPadPro 设备不在线。")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"未知的在线状态: {login_state:}")
|
||||
return False
|
||||
# Code == 300 为微信退出状态。
|
||||
elif response.status == 200 and response_data.get("Code") == 300:
|
||||
logger.info("WeChatPadPro 设备已退出。")
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检查在线状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
return False
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态时发生错误: {e}")
|
||||
return False
|
||||
|
||||
async def generate_auth_key(self):
|
||||
"""
|
||||
生成授权码。
|
||||
"""
|
||||
url = f"{self.base_url}/admin/GenAuthKey1"
|
||||
params = {"key": self.admin_key}
|
||||
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
# 修正成功判断条件和授权码提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 授权码在 Data 字段的列表中
|
||||
if (
|
||||
response_data.get("Data")
|
||||
and isinstance(response_data["Data"], list)
|
||||
and len(response_data["Data"]) > 0
|
||||
):
|
||||
self.auth_key = response_data["Data"][0]
|
||||
logger.info("成功获取授权码")
|
||||
else:
|
||||
logger.error(
|
||||
f"生成授权码成功但未找到授权码: {response_data}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"生成授权码失败: {response.status}, {response_data}"
|
||||
)
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"生成授权码时发生错误: {e}")
|
||||
|
||||
async def get_login_qr_code(self):
|
||||
"""
|
||||
获取登录二维码地址。
|
||||
"""
|
||||
url = f"{self.base_url}/login/GetLoginQrCodeNew"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {} # 根据文档,这个接口的 body 可以为空
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
# 修正成功判断条件和数据提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 二维码地址在 Data.QrCodeUrl 字段中
|
||||
if response_data.get("Data") and response_data["Data"].get(
|
||||
"QrCodeUrl"
|
||||
):
|
||||
return response_data["Data"]["QrCodeUrl"]
|
||||
else:
|
||||
logger.error(
|
||||
f"获取登录二维码成功但未找到二维码地址: {response_data}"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
logger.error(
|
||||
f"获取登录二维码失败: {response.status}, {response_data}"
|
||||
)
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取登录二维码时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def check_login_status(self):
|
||||
"""
|
||||
循环检测扫码状态。
|
||||
尝试 6 次后跳出循环,添加倒计时。
|
||||
返回 True 如果登录成功,否则返回 False。
|
||||
"""
|
||||
url = f"{self.base_url}/login/CheckLoginStatus"
|
||||
params = {"key": self.auth_key}
|
||||
|
||||
attempts = 0 # 初始化尝试次数
|
||||
max_attempts = 36 # 最大尝试次数
|
||||
countdown = 180 # 倒计时时长
|
||||
logger.info(f"请在 {countdown} 秒内扫码登录。")
|
||||
while attempts < max_attempts:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response_data = await response.json()
|
||||
# 成功判断条件和数据提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
if (
|
||||
response_data.get("Data")
|
||||
and response_data["Data"].get("state") is not None
|
||||
):
|
||||
status = response_data["Data"]["state"]
|
||||
logger.info(
|
||||
f"第 {attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}秒"
|
||||
)
|
||||
if status == 2: # 状态 2 表示登录成功
|
||||
self.wxid = response_data["Data"].get("wxid")
|
||||
self.wxnewpass = response_data["Data"].get(
|
||||
"wxnewpass"
|
||||
)
|
||||
logger.info(
|
||||
f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}"
|
||||
)
|
||||
self.save_credentials() # 登录成功后保存凭据
|
||||
return True
|
||||
elif status == -2: # 二维码过期
|
||||
logger.error("二维码已过期,请重新获取。")
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检测登录状态成功但未找到登录状态: {response_data}"
|
||||
)
|
||||
elif response_data.get("Code") == 300:
|
||||
# "不存在状态"
|
||||
pass
|
||||
else:
|
||||
logger.info(
|
||||
f"检测登录状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
await asyncio.sleep(5)
|
||||
attempts += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"检测登录状态时发生错误: {e}")
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
attempts += 1
|
||||
await asyncio.sleep(5) # 每隔5秒检测一次
|
||||
logger.warning("登录检测超过最大尝试次数,退出检测。")
|
||||
return False
|
||||
|
||||
async def connect_websocket(self):
|
||||
"""
|
||||
建立 WebSocket 连接并处理接收到的消息。
|
||||
"""
|
||||
os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}"
|
||||
ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}"
|
||||
logger.info(
|
||||
f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
logger.info("WebSocket 连接成功。")
|
||||
# 设置空闲超时重连
|
||||
wait_time = (
|
||||
self.active_message_poll_interval
|
||||
if self.active_mesasge_poll
|
||||
else 120
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(), timeout=wait_time
|
||||
)
|
||||
# logger.debug(message) # 不显示原始消息内容
|
||||
asyncio.create_task(self.handle_websocket_message(message))
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"WebSocket 连接空闲超过 {wait_time} s")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
logger.info("WebSocket 连接正常关闭。")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
"""
|
||||
处理从 WebSocket 接收到的消息。
|
||||
"""
|
||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||
try:
|
||||
message_data = json.loads(message)
|
||||
if (
|
||||
message_data.get("msg_id") is not None
|
||||
and message_data.get("from_user_name") is not None
|
||||
):
|
||||
abm = await self.convert_message(message_data)
|
||||
if abm:
|
||||
# 创建 WeChatPadProMessageEvent 实例
|
||||
message_event = WeChatPadProMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
# 传递适配器实例,以便在事件中调用 send 方法
|
||||
adapter=self,
|
||||
)
|
||||
# 提交事件到事件队列
|
||||
self.commit_event(message_event)
|
||||
else:
|
||||
logger.warning(f"收到未知结构的 WebSocket 消息: {message_data}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析 WebSocket 消息为 JSON: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
|
||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||
"""
|
||||
将 WeChatPadPro 原始消息转换为 AstrBotMessage。
|
||||
"""
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = raw_message
|
||||
abm.message_id = str(raw_message.get("msg_id"))
|
||||
abm.timestamp = raw_message.get("create_time")
|
||||
abm.self_id = self.wxid
|
||||
|
||||
if int(time.time()) - abm.timestamp > 180:
|
||||
logger.warning(
|
||||
f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}。"
|
||||
)
|
||||
return None
|
||||
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
content = raw_message.get("content", {}).get("str", "")
|
||||
push_content = raw_message.get("push_content", "")
|
||||
msg_type = raw_message.get("msg_type")
|
||||
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
|
||||
# 如果是机器人自己发送的消息、回显消息或系统消息,忽略
|
||||
if from_user_name == self.wxid:
|
||||
logger.info("忽略来自自己的消息。")
|
||||
return None
|
||||
|
||||
if from_user_name in ["weixin", "newsapp", "newsapp_wechat"]:
|
||||
logger.info("忽略来自微信团队的消息。")
|
||||
return None
|
||||
|
||||
# 先判断群聊/私聊并设置基本属性
|
||||
if await self._process_chat_type(
|
||||
abm, raw_message, from_user_name, to_user_name, content, push_content
|
||||
):
|
||||
# 再根据消息类型处理消息内容
|
||||
await self._process_message_content(abm, raw_message, msg_type, content)
|
||||
|
||||
return abm
|
||||
return None
|
||||
|
||||
async def _process_chat_type(
|
||||
self,
|
||||
abm: AstrBotMessage,
|
||||
raw_message: dict,
|
||||
from_user_name: str,
|
||||
to_user_name: str,
|
||||
content: str,
|
||||
push_content: str,
|
||||
):
|
||||
"""
|
||||
判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。
|
||||
"""
|
||||
if from_user_name == "weixin":
|
||||
return False
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = from_user_name
|
||||
|
||||
parts = content.split(":\n", 1)
|
||||
sender_wxid = parts[0] if len(parts) == 2 else ""
|
||||
abm.sender = MessageMember(user_id=sender_wxid, nickname="")
|
||||
|
||||
# 获取群聊发送者的nickname
|
||||
if sender_wxid:
|
||||
accurate_nickname = await self._get_group_member_nickname(
|
||||
abm.group_id, sender_wxid
|
||||
)
|
||||
if accurate_nickname:
|
||||
abm.sender.nickname = accurate_nickname
|
||||
|
||||
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
|
||||
if self.unique_session:
|
||||
abm.session_id = f"{from_user_name}_{to_user_name}"
|
||||
else:
|
||||
abm.session_id = from_user_name
|
||||
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if self.wxid in msg_source:
|
||||
at_me = True
|
||||
if "在群聊中@了你" in raw_message.get("push_content", ""):
|
||||
at_me = True
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=""))
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.group_id = ""
|
||||
nick_name = ""
|
||||
if push_content and " : " in push_content:
|
||||
nick_name = push_content.split(" : ")[0]
|
||||
abm.sender = MessageMember(user_id=from_user_name, nickname=nick_name)
|
||||
abm.session_id = from_user_name
|
||||
return True
|
||||
|
||||
async def _get_group_member_nickname(
|
||||
self, group_id: str, member_wxid: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
通过接口获取群成员的昵称。
|
||||
"""
|
||||
url = f"{self.base_url}/group/GetChatroomMemberDetail"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"ChatRoomName": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 从返回数据中查找对应成员的昵称
|
||||
member_list = (
|
||||
response_data.get("Data", {})
|
||||
.get("member_data", {})
|
||||
.get("chatroom_member_list", [])
|
||||
)
|
||||
for member in member_list:
|
||||
if member.get("user_name") == member_wxid:
|
||||
return member.get("nick_name")
|
||||
logger.warning(
|
||||
f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"获取群成员详情失败: {response.status}, {response_data}"
|
||||
)
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群成员详情时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _download_raw_image(
|
||||
self, from_user_name: str, to_user_name: str, msg_id: int
|
||||
):
|
||||
"""下载原始图片。"""
|
||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"CompressType": 0,
|
||||
"FromUserName": from_user_name,
|
||||
"MsgId": msg_id,
|
||||
"Section": {"DataLen": 61440, "StartPos": 0},
|
||||
"ToUserName": to_user_name,
|
||||
"TotalLen": 0,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
logger.error(f"下载图片失败: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def download_voice(
|
||||
self, to_user_name: str, new_msg_id: str, bufid: str, length: int
|
||||
):
|
||||
"""下载原始音频。"""
|
||||
url = f"{self.base_url}/message/GetMsgVoice"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"Bufid": bufid,
|
||||
"ToUserName": to_user_name,
|
||||
"NewMsgId": new_msg_id,
|
||||
"Length": length,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
logger.error(f"下载音频失败: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _process_message_content(
|
||||
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
|
||||
):
|
||||
"""
|
||||
根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。
|
||||
"""
|
||||
if msg_type == 1: # 文本消息
|
||||
abm.message_str = content
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
parts = content.split(":\n", 1)
|
||||
if len(parts) == 2:
|
||||
message_content = parts[1]
|
||||
abm.message_str = message_content
|
||||
|
||||
# 检查是否@了机器人,参考 gewechat 的实现方式
|
||||
# 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符(四分之一空格)
|
||||
at_me = False
|
||||
|
||||
# 检查 msg_source 中是否包含机器人的 wxid
|
||||
# wechatpadpro 的格式: <atuserlist>wxid</atuserlist>
|
||||
# gewechat 的格式: <atuserlist><![CDATA[wxid]]></atuserlist>
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if f"<atuserlist>{abm.self_id}</atuserlist>" in msg_source or f"<atuserlist>{abm.self_id}," in msg_source or f",{abm.self_id}</atuserlist>" in msg_source:
|
||||
at_me = True
|
||||
|
||||
# 也检查 push_content 中是否有@提示
|
||||
push_content = raw_message.get("push_content", "")
|
||||
if "在群聊中@了你" in push_content:
|
||||
at_me = True
|
||||
|
||||
if at_me:
|
||||
# 被@了,在消息开头插入At组件(参考gewechat的做法)
|
||||
bot_nickname = await self._get_group_member_nickname(abm.group_id, abm.self_id)
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=bot_nickname or abm.self_id))
|
||||
|
||||
# 只有当消息内容不仅仅是@时才添加Plain组件
|
||||
if "\u2005" in message_content:
|
||||
# 检查@之后是否还有其他内容
|
||||
parts = message_content.split("\u2005")
|
||||
if len(parts) > 1 and any(part.strip() for part in parts[1:]):
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
# 检查是否只包含@机器人
|
||||
is_pure_at = False
|
||||
if bot_nickname and message_content.strip() == f"@{bot_nickname}":
|
||||
is_pure_at = True
|
||||
if not is_pure_at:
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
# 没有@机器人,作为普通文本处理
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
else: # 私聊消息
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
|
||||
# 缓存文本消息,以便引用消息可以查找
|
||||
try:
|
||||
# 获取msg_id作为缓存的key
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id:
|
||||
# 限制缓存大小
|
||||
if (
|
||||
len(self.cached_texts) >= self.max_text_cache
|
||||
and self.cached_texts
|
||||
):
|
||||
# 删除最早的一条缓存
|
||||
oldest_key = next(iter(self.cached_texts))
|
||||
self.cached_texts.pop(oldest_key)
|
||||
|
||||
logger.debug(f"缓存文本消息,new_msg_id={new_msg_id}")
|
||||
self.cached_texts[str(new_msg_id)] = content
|
||||
except Exception as e:
|
||||
logger.error(f"缓存文本消息失败: {e}")
|
||||
elif msg_type == 3:
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
msg_id = raw_message.get("msg_id")
|
||||
image_resp = await self._download_raw_image(
|
||||
from_user_name, to_user_name, msg_id
|
||||
)
|
||||
image_bs64_data = (
|
||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||
)
|
||||
if image_bs64_data:
|
||||
abm.message.append(Image.fromBase64(image_bs64_data))
|
||||
# 缓存图片,以便引用消息可以查找
|
||||
try:
|
||||
# 获取msg_id作为缓存的key
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id:
|
||||
# 限制缓存大小
|
||||
if (
|
||||
len(self.cached_images) >= self.max_image_cache
|
||||
and self.cached_images
|
||||
):
|
||||
# 删除最早的一条缓存
|
||||
oldest_key = next(iter(self.cached_images))
|
||||
self.cached_images.pop(oldest_key)
|
||||
|
||||
logger.debug(f"缓存图片消息,new_msg_id={new_msg_id}")
|
||||
self.cached_images[str(new_msg_id)] = image_bs64_data
|
||||
except Exception as e:
|
||||
logger.error(f"缓存图片消息失败: {e}")
|
||||
elif msg_type == 47:
|
||||
# 视频消息 (注意:表情消息也是 47,需要区分)
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
emoji_message = data_parser.parse_emoji()
|
||||
if emoji_message is not None:
|
||||
abm.message.append(emoji_message)
|
||||
elif msg_type == 50:
|
||||
logger.warning("收到语音/视频消息,待实现。")
|
||||
elif msg_type == 34:
|
||||
# 语音消息
|
||||
bufid = 0
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
|
||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||
bufid = voicemsg.get("bufid") or "0"
|
||||
length = int(voicemsg.get("length") or 0)
|
||||
voice_resp = await self.download_voice(
|
||||
to_user_name=to_user_name,
|
||||
new_msg_id=new_msg_id,
|
||||
bufid=bufid,
|
||||
length=length,
|
||||
)
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_bs64_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
elif msg_type == 49:
|
||||
try:
|
||||
parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
cached_texts=self.cached_texts,
|
||||
cached_images=self.cached_images,
|
||||
raw_message=raw_message,
|
||||
downloader=self._download_raw_image,
|
||||
)
|
||||
components = await parser.parse_mutil_49()
|
||||
if components:
|
||||
abm.message.extend(components)
|
||||
abm.message_str = "\n".join(
|
||||
c.text for c in components if isinstance(c, Plain)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"msg_type 49 处理失败: {e}")
|
||||
abm.message.append(Plain("[XML 消息处理失败]"))
|
||||
abm.message_str = "[XML 消息处理失败]"
|
||||
else:
|
||||
logger.warning(f"收到未处理的消息类型: {msg_type}。")
|
||||
|
||||
async def terminate(self):
|
||||
"""
|
||||
终止一个平台的运行实例。
|
||||
"""
|
||||
logger.info("终止 WeChatPadPro 适配器。")
|
||||
try:
|
||||
if self.ws_handle_task:
|
||||
self.ws_handle_task.cancel()
|
||||
self._shutdown_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""
|
||||
得到一个平台的元数据。
|
||||
"""
|
||||
return self.metadata
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
dummy_message_obj = AstrBotMessage()
|
||||
dummy_message_obj.session_id = session.session_id
|
||||
# 根据 session_id 判断消息类型
|
||||
if "@chatroom" in session.session_id:
|
||||
dummy_message_obj.type = MessageType.GROUP_MESSAGE
|
||||
dummy_message_obj.group_id = session.session_id
|
||||
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||
else:
|
||||
dummy_message_obj.type = MessageType.FRIEND_MESSAGE
|
||||
dummy_message_obj.group_id = ""
|
||||
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||
sending_event = WeChatPadProMessageEvent(
|
||||
message_str="",
|
||||
message_obj=dummy_message_obj,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
adapter=self,
|
||||
)
|
||||
# 调用实例方法 send
|
||||
await sending_event.send(message_chain)
|
||||
|
||||
async def get_contact_list(self):
|
||||
"""
|
||||
获取联系人列表。
|
||||
"""
|
||||
url = f"{self.base_url}/friend/GetContactList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = (
|
||||
result.get("Data", {})
|
||||
.get("ContactList", {})
|
||||
.get("contactUsernameList", [])
|
||||
)
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人列表时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def get_contact_details_list(
|
||||
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
获取联系人详情列表。
|
||||
"""
|
||||
if room_wx_id_list is None:
|
||||
room_wx_id_list = []
|
||||
if user_names is None:
|
||||
user_names = []
|
||||
url = f"{self.base_url}/friend/GetContactDetailsList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人详情列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = result.get("Data", {}).get("contactList", {})
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人详情列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人详情列表时发生错误: {e}")
|
||||
return None
|
||||
@@ -0,0 +1,157 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image as PILImage # 使用别名避免冲突
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import (
|
||||
Image,
|
||||
Plain,
|
||||
WechatEmoji,
|
||||
Record,
|
||||
) # Import Image
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||
|
||||
|
||||
class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
adapter: "WeChatPadProAdapter", # 传递适配器实例
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.message_obj = message_obj # Save the full message object
|
||||
self.adapter = adapter # Save the adapter instance
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for comp in message.chain:
|
||||
await asyncio.sleep(1)
|
||||
if isinstance(comp, Plain):
|
||||
await self._send_text(session, comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
await self._send_image(session, comp)
|
||||
elif isinstance(comp, WechatEmoji):
|
||||
await self._send_emoji(session, comp)
|
||||
elif isinstance(comp, Record):
|
||||
await self._send_voice(session, comp)
|
||||
await super().send(message)
|
||||
|
||||
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||
b64 = await comp.convert_to_base64()
|
||||
raw = self._validate_base64(b64)
|
||||
b64c = self._compress_image(raw)
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendImageNewMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_text(self, session: aiohttp.ClientSession, text: str):
|
||||
if (
|
||||
self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息
|
||||
and self.adapter.settings.get(
|
||||
"reply_with_mention", False
|
||||
) # 检查适配器设置是否启用 reply_with_mention
|
||||
and self.message_obj.sender # 确保有发送者信息
|
||||
and (
|
||||
self.message_obj.sender.user_id or self.message_obj.sender.nickname
|
||||
) # 确保发送者有 ID 或昵称
|
||||
):
|
||||
# 优先使用 nickname,如果没有则使用 user_id
|
||||
mention_text = (
|
||||
self.message_obj.sender.nickname or self.message_obj.sender.user_id
|
||||
)
|
||||
message_text = f"@{mention_text} {text}"
|
||||
# logger.info(f"已添加 @ 信息: {message_text}")
|
||||
else:
|
||||
message_text = text
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{
|
||||
"MsgType": 1,
|
||||
"TextContent": message_text,
|
||||
"ToUserName": self.session_id,
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendTextMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji):
|
||||
payload = {
|
||||
"EmojiList": [
|
||||
{
|
||||
"EmojiMd5": comp.md5,
|
||||
"EmojiSize": comp.md5_len,
|
||||
"ToUserName": self.session_id,
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendEmojiMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 默认已经存在 data/temp 中
|
||||
b64, duration = await wav_to_tencent_silk_base64(record_path)
|
||||
payload = {
|
||||
"ToUserName": self.session_id,
|
||||
"VoiceData": b64,
|
||||
"VoiceFormat": 4,
|
||||
"VoiceSecond": duration,
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendVoice"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
@staticmethod
|
||||
def _validate_base64(b64: str) -> bytes:
|
||||
return base64.b64decode(b64, validate=True)
|
||||
|
||||
@staticmethod
|
||||
def _compress_image(data: bytes) -> str:
|
||||
img = PILImage.open(io.BytesIO(data))
|
||||
buf = io.BytesIO()
|
||||
if img.format == "JPEG":
|
||||
img.save(buf, "JPEG", quality=80)
|
||||
else:
|
||||
if img.mode in ("RGBA", "P"):
|
||||
img = img.convert("RGB")
|
||||
img.save(buf, "JPEG", quality=80)
|
||||
# logger.info("图片处理完成!!!")
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
async def _post(self, session, url, payload):
|
||||
params = {"key": self.adapter.auth_key}
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as resp:
|
||||
data = await resp.json()
|
||||
if resp.status != 200 or data.get("Code") != 200:
|
||||
logger.error(f"{url} failed: {resp.status} {data}")
|
||||
except Exception as e:
|
||||
logger.error(f"{url} error: {e}")
|
||||
|
||||
|
||||
# TODO: 添加对其他消息组件类型的处理 (Record, Video, At等)
|
||||
# elif isinstance(component, Record):
|
||||
# pass
|
||||
# elif isinstance(component, Video):
|
||||
# pass
|
||||
# elif isinstance(component, At):
|
||||
# pass
|
||||
# ...
|
||||
160
astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
Normal file
160
astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Plain,
|
||||
Image,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
def __init__(
|
||||
self,
|
||||
content: str,
|
||||
is_private_chat: bool = False,
|
||||
cached_texts=None,
|
||||
cached_images=None,
|
||||
raw_message: dict = None,
|
||||
downloader=None,
|
||||
):
|
||||
self._xml = None
|
||||
self.content = content
|
||||
self.is_private_chat = is_private_chat
|
||||
self.cached_texts = cached_texts or {}
|
||||
self.cached_images = cached_images or {}
|
||||
self.downloader = downloader
|
||||
|
||||
raw_message = raw_message or {}
|
||||
self.from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
self.to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
self.msg_id = raw_message.get("msg_id", "")
|
||||
|
||||
def _format_to_xml(self):
|
||||
if self._xml:
|
||||
return self._xml
|
||||
|
||||
try:
|
||||
msg_str = self.content
|
||||
if not self.is_private_chat:
|
||||
parts = self.content.split(":\n", 1)
|
||||
msg_str = parts[1] if len(parts) == 2 else self.content
|
||||
|
||||
self._xml = eT.fromstring(msg_str)
|
||||
return self._xml
|
||||
except Exception as e:
|
||||
logger.error(f"[XML解析失败] {e}")
|
||||
raise
|
||||
|
||||
async def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
"""
|
||||
处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57)
|
||||
"""
|
||||
try:
|
||||
appmsg_type = self._format_to_xml().findtext(".//appmsg/type")
|
||||
if appmsg_type == "57":
|
||||
return await self.parse_reply()
|
||||
except Exception as e:
|
||||
logger.warning(f"[parse_mutil_49] 解析失败: {e}")
|
||||
return None
|
||||
|
||||
async def parse_reply(self) -> list[BaseMessageComponent]:
|
||||
"""
|
||||
处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49)
|
||||
"""
|
||||
components = []
|
||||
|
||||
try:
|
||||
appmsg = self._format_to_xml().find("appmsg")
|
||||
if appmsg is None:
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
refermsg = appmsg.find("refermsg")
|
||||
if refermsg is None:
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
quote_type = int(refermsg.findtext("type", "0"))
|
||||
nickname = refermsg.findtext("displayname", "未知发送者")
|
||||
quote_content = refermsg.findtext("content", "")
|
||||
svrid = refermsg.findtext("svrid")
|
||||
|
||||
match quote_type:
|
||||
case 1: # 文本引用
|
||||
quoted_text = self.cached_texts.get(str(svrid), quote_content)
|
||||
components.append(Plain(f"[引用] {nickname}: {quoted_text}"))
|
||||
|
||||
case 3: # 图片引用
|
||||
quoted_image_b64 = self.cached_images.get(str(svrid))
|
||||
if not quoted_image_b64:
|
||||
try:
|
||||
quote_xml = eT.fromstring(quote_content)
|
||||
img = quote_xml.find("img")
|
||||
cdn_url = (
|
||||
img.get("cdnbigimgurl") or img.get("cdnmidimgurl")
|
||||
if img is not None
|
||||
else None
|
||||
)
|
||||
if cdn_url and self.downloader:
|
||||
image_resp = await self.downloader(
|
||||
self.from_user_name, self.to_user_name, self.msg_id
|
||||
)
|
||||
quoted_image_b64 = (
|
||||
image_resp.get("Data", {})
|
||||
.get("Data", {})
|
||||
.get("Buffer")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}")
|
||||
|
||||
if quoted_image_b64:
|
||||
components.extend(
|
||||
[
|
||||
Image.fromBase64(quoted_image_b64),
|
||||
Plain(f"[引用] {nickname}: [引用的图片]"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
components.append(
|
||||
Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]")
|
||||
)
|
||||
|
||||
case 49: # 嵌套引用
|
||||
try:
|
||||
nested_root = eT.fromstring(quote_content)
|
||||
nested_title = nested_root.findtext(".//appmsg/title", "")
|
||||
components.append(Plain(f"[引用] {nickname}: {nested_title}"))
|
||||
except Exception as e:
|
||||
logger.warning(f"[嵌套引用解析失败] err={e}")
|
||||
components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]"))
|
||||
|
||||
case _: # 其他未识别类型
|
||||
logger.info(f"[未知引用类型] quote_type={quote_type}")
|
||||
components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]"))
|
||||
|
||||
# 主消息标题
|
||||
title = appmsg.findtext("title", "")
|
||||
if title:
|
||||
components.append(Plain(title))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[parse_reply] 总体解析失败: {e}")
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
return components
|
||||
|
||||
def parse_emoji(self) -> Emoji | None:
|
||||
"""
|
||||
处理 msg_type == 47 的表情消息(emoji)
|
||||
"""
|
||||
try:
|
||||
emoji_element = self._format_to_xml().find(".//emoji")
|
||||
if emoji_element is not None:
|
||||
return Emoji(
|
||||
md5=emoji_element.get("md5"),
|
||||
md5_len=emoji_element.get("len"),
|
||||
cdnurl=emoji_element.get("cdnurl"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[parse_emoji] 解析失败: {e}")
|
||||
|
||||
return None
|
||||
@@ -1,28 +1,33 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import quart
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
from wechatpy.enterprise import WeChatClient, parse_message
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise.messages import ImageMessage, TextMessage, VoiceMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.messages import BaseMessage
|
||||
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Image, Plain, Record
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core import logger
|
||||
from requests import Response
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.enterprise import parse_message
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
from .wecom_kf import WeChatKF
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -131,9 +136,40 @@ class WecomPlatformAdapter(Platform):
|
||||
self.config["corpid"].strip(),
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
# 微信客服
|
||||
self.kf_name = self.config.get("kf_name", None)
|
||||
if self.kf_name:
|
||||
# inject
|
||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
async def callback(msg):
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
|
||||
def get_latest_msg_item() -> dict | None:
|
||||
token = msg._data["Token"]
|
||||
kfid = msg._data["OpenKfId"]
|
||||
has_more = 1
|
||||
ret = {}
|
||||
while has_more:
|
||||
ret = self.wechat_kf_api.sync_msg(token, kfid)
|
||||
has_more = ret["has_more"]
|
||||
msg_list = ret.get("msg_list", [])
|
||||
if msg_list:
|
||||
return msg_list[-1]
|
||||
return None
|
||||
|
||||
msg_new = await asyncio.get_event_loop().run_in_executor(
|
||||
None, get_latest_msg_item
|
||||
)
|
||||
if msg_new:
|
||||
await self.convert_wechat_kf_message(msg_new)
|
||||
return
|
||||
await self.convert_message(msg)
|
||||
|
||||
self.server.callback = callback
|
||||
@@ -153,9 +189,39 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
if self.kf_name:
|
||||
try:
|
||||
acc_list = (
|
||||
await loop.run_in_executor(
|
||||
None, self.wechat_kf_api.get_account_list
|
||||
)
|
||||
).get("account_list", [])
|
||||
logger.debug(f"获取到微信客服列表: {str(acc_list)}")
|
||||
for acc in acc_list:
|
||||
name = acc.get("name", None)
|
||||
if name != self.kf_name:
|
||||
continue
|
||||
open_kfid = acc.get("open_kfid", None)
|
||||
if not open_kfid:
|
||||
logger.error("获取微信客服失败,open_kfid 为空。")
|
||||
logger.debug(f"Found open_kfid: {str(open_kfid)}")
|
||||
kf_url = (
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self.wechat_kf_api.add_contact_way,
|
||||
open_kfid,
|
||||
"astrbot_placeholder",
|
||||
)
|
||||
).get("url", "")
|
||||
logger.info(
|
||||
f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(self, msg):
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if msg.type == "text":
|
||||
assert isinstance(msg, TextMessage)
|
||||
@@ -191,14 +257,15 @@ class WecomPlatformAdapter(Platform):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, msg.media_id
|
||||
)
|
||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
@@ -218,10 +285,43 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
return
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||
msgtype = msg.get("msgtype", None)
|
||||
external_userid = msg.get("external_userid", None)
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
abm.self_id = msg["open_kfid"]
|
||||
abm.sender = MessageMember(external_userid, external_userid)
|
||||
abm.session_id = external_userid
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
|
||||
if msgtype == "text":
|
||||
text = msg.get("text", {}).get("content", "").strip()
|
||||
abm.message = [Plain(text=text)]
|
||||
abm.message_str = text
|
||||
elif msgtype == "image":
|
||||
media_id = msg.get("image", {}).get("media_id", "")
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, media_id
|
||||
)
|
||||
path = f"data/temp/wechat_kf_{media_id}.jpg"
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
abm.message = [Image(file=path, url=path)]
|
||||
abm.message_str = "[图片]"
|
||||
else:
|
||||
logger.warning(f"未实现的微信客服消息事件: {msg}")
|
||||
return
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = WecomPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import pydub
|
||||
@@ -52,19 +55,29 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
if start + 2048 >= len(plain):
|
||||
result.append(plain[start:])
|
||||
break
|
||||
|
||||
|
||||
# 向前搜索分割标点符号
|
||||
end = min(start + 2048, len(plain))
|
||||
cut_position = end
|
||||
for i in range(end, start, -1):
|
||||
if i < len(plain) and plain[i-1] in ["。", "!", "?", ".", "!", "?", "\n", ";", ";"]:
|
||||
if i < len(plain) and plain[i - 1] in [
|
||||
"。",
|
||||
"!",
|
||||
"?",
|
||||
".",
|
||||
"!",
|
||||
"?",
|
||||
"\n",
|
||||
";",
|
||||
";",
|
||||
]:
|
||||
cut_position = i
|
||||
break
|
||||
|
||||
|
||||
# 没找到合适的位置分割, 直接切分
|
||||
if cut_position == end and end < len(plain):
|
||||
cut_position = end
|
||||
|
||||
|
||||
result.append(plain[start:cut_position])
|
||||
start = cut_position
|
||||
|
||||
@@ -73,57 +86,98 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
self.client.message.send_text(
|
||||
message_obj.self_id, message_obj.session_id, chunk
|
||||
)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
is_wechat_kf = hasattr(self.client, "kf_message")
|
||||
if is_wechat_kf:
|
||||
# 微信客服
|
||||
kf_message_api = getattr(self.client, "kf_message", None)
|
||||
if not kf_message_api:
|
||||
logger.warning("未找到微信客服发送消息方法。")
|
||||
return
|
||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||
user_id = self.get_sender_id()
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传图片返回: {response}")
|
||||
kf_message_api.send_image(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
else:
|
||||
# 企业微信应用
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
self.client.message.send_text(
|
||||
message_obj.self_id, message_obj.session_id, chunk
|
||||
)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
|
||||
278
astrbot/core/platform/sources/wecom/wecom_kf.py
Normal file
278
astrbot/core/platform/sources/wecom/wecom_kf.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014-2020 messense
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
|
||||
class WeChatKF(BaseWeChatAPI):
|
||||
"""
|
||||
微信客服接口
|
||||
|
||||
https://work.weixin.qq.com/api/doc/90000/90135/94670
|
||||
"""
|
||||
|
||||
def sync_msg(self, token, open_kfid, cursor="", limit=1000):
|
||||
"""
|
||||
微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收)
|
||||
、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。
|
||||
支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。
|
||||
|
||||
|
||||
:param token: 回调事件返回的token字段,10分钟内有效;可不填,如果不填接口有严格的频率限制。不多于128字节
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param cursor: 上一次调用时返回的next_cursor,第一次拉取可以不填。不多于64字节
|
||||
:param limit: 期望请求的数据量,默认值和最大值都为1000。
|
||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||
return self._post("kf/sync_msg", data=data)
|
||||
|
||||
def get_service_state(self, open_kfid, external_userid):
|
||||
"""
|
||||
获取会话状态
|
||||
|
||||
ID 状态 说明
|
||||
0 未处理 新会话接入。可选择:1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待
|
||||
1 由智能助手接待 可使用API回复消息。可选择转入待接入池或者指定接待人员处理。
|
||||
2 待接入池排队中 在待接入池中排队等待接待人员接入。可选择转为指定人员接待
|
||||
3 由人工接待 人工接待中。可选择结束会话
|
||||
4 已结束 会话已经结束。不允许变更会话状态,等待用户重新发起咨询
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
}
|
||||
return self._post("kf/service_state/get", data=data)
|
||||
|
||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||
"""
|
||||
变更会话状态
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:param service_state: 当前的会话状态,状态定义参考概述中的表格
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
"service_state": service_state,
|
||||
}
|
||||
if servicer_userid:
|
||||
data["servicer_userid"] = servicer_userid
|
||||
return self._post("kf/service_state/trans", data=data)
|
||||
|
||||
def get_servicer_list(self, open_kfid):
|
||||
"""
|
||||
获取接待人员列表
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
return self._get("kf/servicer/list", params=data)
|
||||
|
||||
def add_servicer(self, open_kfid, userid_list):
|
||||
"""
|
||||
添加接待人员
|
||||
添加指定客服帐号的接待人员。
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param userid_list: 接待人员userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(userid_list, list):
|
||||
userid_list = [userid_list]
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"userid_list": userid_list,
|
||||
}
|
||||
return self._post("kf/servicer/add", data=data)
|
||||
|
||||
def del_servicer(self, open_kfid, userid_list):
|
||||
"""
|
||||
删除接待人员
|
||||
从客服帐号删除接待人员
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param userid_list: 接待人员userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(userid_list, list):
|
||||
userid_list = [userid_list]
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"userid_list": userid_list,
|
||||
}
|
||||
return self._post("kf/servicer/del", data=data)
|
||||
|
||||
def batchget_customer(self, external_userid_list):
|
||||
"""
|
||||
客户基本信息获取
|
||||
|
||||
:param external_userid_list: external_userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(external_userid_list, list):
|
||||
external_userid_list = [external_userid_list]
|
||||
|
||||
data = {
|
||||
"external_userid_list": external_userid_list,
|
||||
}
|
||||
return self._post("kf/customer/batchget", data=data)
|
||||
|
||||
def get_account_list(self):
|
||||
"""
|
||||
获取客服帐号列表
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
return self._get("kf/account/list")
|
||||
|
||||
def add_contact_way(self, open_kfid, scene):
|
||||
"""
|
||||
获取客服帐号链接
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param scene: 场景值,字符串类型,由开发者自定义。不多于32字节;字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "scene": scene}
|
||||
return self._post("kf/add_contact_way", data=data)
|
||||
|
||||
def get_upgrade_service_config(self):
|
||||
"""
|
||||
获取配置的专员与客户群
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
return self._get("kf/customer/get_upgrade_service_config")
|
||||
|
||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||
"""
|
||||
为客户升级为专员或客户群服务
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:param service_type: 表示是升级到专员服务还是客户群服务。1:专员服务。2:客户群服务
|
||||
:param member: 推荐的服务专员,type等于1时有效
|
||||
:param groupchat: 推荐的客户群,type等于2时有效
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
"type": service_type,
|
||||
}
|
||||
if service_type == 1:
|
||||
data["member"] = member
|
||||
else:
|
||||
data["groupchat"] = groupchat
|
||||
return self._post("kf/customer/upgrade_service", data=data)
|
||||
|
||||
def cancel_upgrade_service(self, open_kfid, external_userid):
|
||||
"""
|
||||
为客户取消推荐
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {"open_kfid": open_kfid, "external_userid": external_userid}
|
||||
return self._post("kf/customer/cancel_upgrade_service", data=data)
|
||||
|
||||
def send_msg_on_event(self, code, msgtype, msg_content, msgid=None):
|
||||
"""
|
||||
当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。
|
||||
支持发送消息类型:文本、菜单消息。
|
||||
|
||||
:param code: 事件响应消息对应的code。通过事件回调下发,仅可使用一次。
|
||||
:param msgtype: 消息类型。对不同的msgtype,有相应的结构描述,详见消息类型
|
||||
:param msg_content: 目前支持文本与菜单消息,具体查看文档
|
||||
:param msgid: 消息ID。如果请求参数指定了msgid,则原样返回,否则系统自动生成并返回。不多于32字节;
|
||||
字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {"code": code, "msgtype": msgtype}
|
||||
if msgid:
|
||||
data["msgid"] = msgid
|
||||
data.update(msg_content)
|
||||
return self._post("kf/send_msg_on_event", data=data)
|
||||
|
||||
def get_corp_statistic(self, start_time, end_time, open_kfid=None):
|
||||
"""
|
||||
获取「客户数据统计」企业汇总数据
|
||||
|
||||
:param start_time: 开始时间
|
||||
:param end_time: 结束时间
|
||||
:param open_kfid: 客服帐号ID
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||
return self._post("kf/get_corp_statistic", data=data)
|
||||
|
||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||
"""
|
||||
获取「客户数据统计」接待人员明细数据
|
||||
|
||||
:param start_time: 开始时间
|
||||
:param end_time: 结束时间
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param servicer_userid: 接待人员
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"servicer_userid": servicer_userid,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
return self._post("kf/get_servicer_statistic", data=data)
|
||||
|
||||
def account_update(self, open_kfid, name, media_id):
|
||||
"""
|
||||
修改客服账号
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param name: 客服名称
|
||||
:param media_id: 客服头像临时素材
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "name": name, "media_id": media_id}
|
||||
return self._post("kf/account/update", data=data)
|
||||
159
astrbot/core/platform/sources/wecom/wecom_kf_message.py
Normal file
159
astrbot/core/platform/sources/wecom/wecom_kf_message.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014-2020 messense
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from optionaldict import optionaldict
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
class WeChatKFMessage(BaseWeChatAPI):
|
||||
"""
|
||||
发送微信客服消息
|
||||
|
||||
https://work.weixin.qq.com/api/doc/90000/90135/94677
|
||||
|
||||
支持:
|
||||
* 文本消息
|
||||
* 图片消息
|
||||
* 语音消息
|
||||
* 视频消息
|
||||
* 文件消息
|
||||
* 图文链接
|
||||
* 小程序
|
||||
* 菜单消息
|
||||
* 地理位置
|
||||
"""
|
||||
|
||||
def send(self, user_id, open_kfid, msgid="", msg=None):
|
||||
"""
|
||||
当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。
|
||||
注意仅当微信客户在主动发送消息给客服后的48小时内,企业可发送消息给客户,最多可发送5条消息;若用户继续发送消息,企业可再次下发消息。
|
||||
支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。
|
||||
|
||||
:param user_id: 指定接收消息的客户UserID
|
||||
:param open_kfid: 指定发送消息的客服帐号ID
|
||||
:param msgid: 指定消息ID
|
||||
:param tag_ids: 标签ID列表。
|
||||
:param msg: 发送消息的 dict 对象
|
||||
:type msg: dict | None
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
msg = msg or {}
|
||||
data = {
|
||||
"touser": user_id,
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
if msgid:
|
||||
data["msgid"] = msgid
|
||||
data.update(msg)
|
||||
return self._post("kf/send_msg", data=data)
|
||||
|
||||
def send_text(self, user_id, open_kfid, content, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "text", "text": {"content": content}},
|
||||
)
|
||||
|
||||
def send_image(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "image", "image": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_voice(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "voice", "voice": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_video(self, user_id, open_kfid, media_id, msgid=""):
|
||||
video_data = optionaldict()
|
||||
video_data["media_id"] = media_id
|
||||
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "video", "video": dict(video_data)},
|
||||
)
|
||||
|
||||
def send_file(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "file", "file": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_articles_link(self, user_id, open_kfid, article, msgid=""):
|
||||
articles_data = {
|
||||
"title": article["title"],
|
||||
"desc": article["desc"],
|
||||
"url": article["url"],
|
||||
"thumb_media_id": article["thumb_media_id"],
|
||||
}
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||
)
|
||||
|
||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "msgmenu",
|
||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||
},
|
||||
)
|
||||
|
||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "location",
|
||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||
},
|
||||
)
|
||||
|
||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "miniprogram",
|
||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,286 @@
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import quart
|
||||
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core import logger
|
||||
from requests import Response
|
||||
|
||||
from wechatpy.utils import check_signature
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy import parse_message
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.token = config.get("token")
|
||||
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||
self.appid = config.get("appid")
|
||||
self.server.add_url_rule(
|
||||
"/callback/command", view_func=self.verify, methods=["GET"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/callback/command", view_func=self.callback_command, methods=["POST"]
|
||||
)
|
||||
self.crypto = WeChatCrypto(self.token, self.encoding_aes_key, self.appid)
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.callback = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||
|
||||
args = quart.request.args
|
||||
if not args.get("signature", None):
|
||||
logger.error("未知的响应,请检查回调地址是否填写正确。")
|
||||
return "err"
|
||||
try:
|
||||
check_signature(
|
||||
self.token,
|
||||
args.get("signature"),
|
||||
args.get("timestamp"),
|
||||
args.get("nonce"),
|
||||
)
|
||||
logger.info("验证请求有效性成功。")
|
||||
return args.get("echostr", "empty")
|
||||
except InvalidSignatureException:
|
||||
logger.error("验证请求有效性失败,签名异常,请检查配置。")
|
||||
return "err"
|
||||
|
||||
async def callback_command(self):
|
||||
data = await quart.request.get_data()
|
||||
msg_signature = quart.request.args.get("msg_signature")
|
||||
timestamp = quart.request.args.get("timestamp")
|
||||
nonce = quart.request.args.get("nonce")
|
||||
try:
|
||||
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||
except InvalidSignatureException:
|
||||
logger.error("解密失败,签名异常,请检查配置。")
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
result_xml = await self.callback(msg)
|
||||
if not result_xml:
|
||||
return "success"
|
||||
if isinstance(result_xml, str):
|
||||
return result_xml
|
||||
|
||||
return "success"
|
||||
|
||||
async def start_polling(self):
|
||||
logger.info(
|
||||
f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。"
|
||||
)
|
||||
await self.server.run_task(
|
||||
host=self.callback_server_host,
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
|
||||
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
|
||||
class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
self.api_base_url = platform_config.get(
|
||||
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
|
||||
)
|
||||
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||
|
||||
if self.api_base_url.endswith("/"):
|
||||
self.api_base_url = self.api_base_url[:-1]
|
||||
if not self.api_base_url.endswith("/cgi-bin"):
|
||||
self.api_base_url += "/cgi-bin"
|
||||
|
||||
if not self.api_base_url.endswith("/"):
|
||||
self.api_base_url += "/"
|
||||
|
||||
self.server = WecomServer(self._event_queue, self.config)
|
||||
|
||||
self.client = WeChatClient(
|
||||
self.config["appid"].strip(),
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
self.wexin_event_workers: dict[str, asyncio.Future] = {}
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
try:
|
||||
if self.active_send_mode:
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息时出现异常: {e}")
|
||||
|
||||
self.server.callback = callback
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(
|
||||
self, msg, future: asyncio.Future = None
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, msg.media_id
|
||||
)
|
||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
|
||||
)
|
||||
path_wav = path
|
||||
return
|
||||
|
||||
abm.message_str = ""
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
future.set_result(None)
|
||||
return
|
||||
# 很不优雅 :(
|
||||
abm.raw_message = {
|
||||
"message": msg,
|
||||
"future": future,
|
||||
"active_send_mode": self.active_send_mode,
|
||||
}
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = WeixinOfficialAccountPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self) -> WeChatClient:
|
||||
return self.client
|
||||
|
||||
async def terminate(self):
|
||||
self.server.shutdown_event.set()
|
||||
try:
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("微信公众平台 适配器已被优雅地关闭")
|
||||
@@ -0,0 +1,185 @@
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import TextReply, ImageReply, VoiceReply
|
||||
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: WeChatClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(
|
||||
client: WeChatClient, message: MessageChain, user_name: str
|
||||
):
|
||||
pass
|
||||
|
||||
async def split_plain(self, plain: str) -> list[str]:
|
||||
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||
|
||||
Args:
|
||||
plain (str): 要分割的长文本
|
||||
Returns:
|
||||
list[str]: 分割后的文本列表
|
||||
"""
|
||||
if len(plain) <= 2048:
|
||||
return [plain]
|
||||
else:
|
||||
result = []
|
||||
start = 0
|
||||
while start < len(plain):
|
||||
# 剩下的字符串长度<2048时结束
|
||||
if start + 2048 >= len(plain):
|
||||
result.append(plain[start:])
|
||||
break
|
||||
|
||||
# 向前搜索分割标点符号
|
||||
end = min(start + 2048, len(plain))
|
||||
cut_position = end
|
||||
for i in range(end, start, -1):
|
||||
if i < len(plain) and plain[i - 1] in [
|
||||
"。",
|
||||
"!",
|
||||
"?",
|
||||
".",
|
||||
"!",
|
||||
"?",
|
||||
"\n",
|
||||
";",
|
||||
";",
|
||||
]:
|
||||
cut_position = i
|
||||
break
|
||||
|
||||
# 没找到合适的位置分割, 直接切分
|
||||
if cut_position == end and end < len(plain):
|
||||
cut_position = end
|
||||
|
||||
result.append(plain[start:cut_position])
|
||||
start = cut_position
|
||||
|
||||
return result
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
if active_send_mode:
|
||||
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信公众平台上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信公众平台上传语音失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -155,7 +156,9 @@ class ProviderRequest:
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
|
||||
@@ -3,8 +3,8 @@ import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
@@ -13,12 +13,21 @@ from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
@@ -95,7 +104,10 @@ class MCPClient:
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
@@ -107,15 +119,41 @@ class MCPClient:
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
|
||||
if "url" in cfg:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(url=cfg["url"])
|
||||
streams = await self._streams_context.__aenter__()
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
# self.session = await self._session_context.__aenter__()
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
@@ -239,8 +277,7 @@ class FuncCall:
|
||||
}
|
||||
```
|
||||
"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||
if not os.path.exists(mcp_json_file):
|
||||
@@ -360,7 +397,7 @@ class FuncCall:
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return True
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
@@ -369,7 +406,7 @@ class FuncCall:
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return False
|
||||
return
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -441,28 +478,51 @@ class FuncCall:
|
||||
"""
|
||||
|
||||
# Gemini API 支持的数据类型和格式
|
||||
supported_types = {"string", "number", "integer", "boolean", "array", "object", "null"}
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"}
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""转换 schema 为 Gemini API 格式"""
|
||||
|
||||
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if ("format" in schema and
|
||||
schema["format"] in supported_formats.get(result["type"], set())):
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
# 暂时指定默认为null
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {"title", "description", "enum", "minimum", "maximum",
|
||||
"maxItems", "minItems", "nullable", "required"}
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
@@ -478,8 +538,6 @@ class FuncCall:
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
if "anyOf" in schema:
|
||||
result["anyOf"] = [convert_schema(s) for s in schema["anyOf"]]
|
||||
|
||||
return result
|
||||
|
||||
@@ -487,9 +545,10 @@ class FuncCall:
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
**({"parameters": convert_schema(f.parameters)})
|
||||
**({"parameters": convert_schema(f.parameters)}),
|
||||
}
|
||||
for f in self.func_list if f.active
|
||||
for f in self.func_list
|
||||
if f.active
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
|
||||
@@ -18,13 +18,6 @@ class ProviderManager:
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
self.selected_provider_id = sp.get("curr_provider")
|
||||
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
self.provider_enabled = self.provider_settings.get("enable", False)
|
||||
self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
@@ -98,15 +91,18 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
self.curr_provider_inst: Provider = None
|
||||
"""当前使用的 Provider 实例"""
|
||||
"""默认的 Provider 实例"""
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
"""当前使用的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
"""当前使用的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
@@ -115,18 +111,57 @@ class ProviderManager:
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo and self.provider_settings["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
# 设置默认提供商
|
||||
self.curr_provider_inst = self.inst_map.get(
|
||||
self.provider_settings.get("default_provider_id")
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
if self.stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
self.curr_stt_provider_inst = self.inst_map.get(
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
self.curr_tts_provider_inst = self.inst_map.get(
|
||||
self.provider_tts_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
@@ -202,6 +237,26 @@ class ProviderManager:
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -234,14 +289,14 @@ class ProviderManager:
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_stt_provider_id == provider_config["id"]
|
||||
and self.stt_enabled
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_enabled:
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
@@ -254,15 +309,12 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_tts_provider_id == provider_config["id"]
|
||||
and self.tts_enabled
|
||||
):
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_enabled:
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
@@ -280,16 +332,24 @@ class ProviderManager:
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_provider_id == provider_config["id"]
|
||||
and self.provider_enabled
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_enabled:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -310,39 +370,24 @@ class ProviderManager:
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif (
|
||||
self.curr_provider_inst is None
|
||||
and len(self.provider_insts) > 0
|
||||
and self.provider_enabled
|
||||
):
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None
|
||||
and len(self.stt_provider_insts) > 0
|
||||
and self.stt_enabled
|
||||
):
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None
|
||||
and len(self.tts_provider_insts) > 0
|
||||
and self.tts_enabled
|
||||
):
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
|
||||
@@ -179,3 +179,25 @@ class TTSProvider(AbstractProvider):
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
@@ -104,11 +104,13 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
if not prompt:
|
||||
prompt = "<image>"
|
||||
|
||||
|
||||
210
astrbot/core/provider/sources/azure_tts_source.py
Normal file
210
astrbot/core/provider/sources/azure_tts_source.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
import hashlib
|
||||
import random
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from httpx import AsyncClient, Timeout
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
TEMP_DIR = Path("data/temp/azure_tts")
|
||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
class OTTSProvider:
|
||||
def __init__(self, config: Dict):
|
||||
self.skey = config["OTTS_SKEY"]
|
||||
self.api_url = config["OTTS_URL"]
|
||||
self.auth_time_url = config["OTTS_AUTH_TIME"]
|
||||
self.time_offset = 0
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.client = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(timeout=self.timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
async def _sync_time(self):
|
||||
try:
|
||||
response = await self.client.get(self.auth_time_url)
|
||||
response.raise_for_status()
|
||||
server_time = int(response.json()["timestamp"])
|
||||
local_time = int(time.time())
|
||||
self.time_offset = server_time - local_time
|
||||
self.last_sync_time = local_time
|
||||
except Exception as e:
|
||||
if time.time() - self.last_sync_time > 3600:
|
||||
raise RuntimeError("时间同步失败") from e
|
||||
|
||||
async def _generate_signature(self) -> str:
|
||||
await self._sync_time()
|
||||
timestamp = int(time.time()) + self.time_offset
|
||||
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
|
||||
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
||||
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
||||
|
||||
async def get_audio(self, text: str, voice_params: Dict) -> str:
|
||||
file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav"
|
||||
signature = await self._generate_signature()
|
||||
for attempt in range(self.retry_count):
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.api_url}?sign={signature}",
|
||||
data={
|
||||
"text": text,
|
||||
"voice": voice_params["voice"],
|
||||
"style": voice_params["style"],
|
||||
"role": voice_params["role"],
|
||||
"rate": voice_params["rate"],
|
||||
"volume": voice_params["volume"]
|
||||
},
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"UAK": "AstrBot/AzureTTS"
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
async for chunk in response.aiter_bytes(4096):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
except Exception as e:
|
||||
if attempt == self.retry_count - 1:
|
||||
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
|
||||
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||
raise ValueError("无效的Azure订阅密钥")
|
||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
self.client = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
self.voice_params = {
|
||||
"voice": provider_config.get("azure_tts_voice", "zh-CN-YunxiaNeural"),
|
||||
"style": provider_config.get("azure_tts_style", "cheerful"),
|
||||
"role": provider_config.get("azure_tts_role", "Boy"),
|
||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||
"volume": provider_config.get("azure_tts_volume", "100")
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
|
||||
})
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
||||
response = await self.client.post(
|
||||
token_url,
|
||||
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
self.token = response.text
|
||||
self.token_expire = time.time() + 540
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
if not self.token or time.time() > self.token_expire:
|
||||
await self._refresh_token()
|
||||
file_path = TEMP_DIR / f"azure-{uuid.uuid4()}.wav"
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis'
|
||||
xmlns:mstts='http://www.w3.org/2001/mstts' xml:lang='zh-CN'>
|
||||
<voice name='{escape(self.voice_params["voice"])}'>
|
||||
<mstts:express-as style='{escape(self.voice_params["style"])}'
|
||||
role='{escape(self.voice_params["role"])}'>
|
||||
<prosody rate='{escape(self.voice_params["rate"])}'
|
||||
volume='{escape(self.voice_params["volume"])}'>
|
||||
{escape(text)}
|
||||
</prosody>
|
||||
</mstts:express-as>
|
||||
</voice>
|
||||
</speak>"""
|
||||
response = await self.client.post(
|
||||
self.endpoint,
|
||||
content=ssml,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.token}",
|
||||
"User-Agent": f"AstrBot/{VERSION}"
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
for chunk in response.iter_bytes(4096):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
|
||||
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
|
||||
class AzureTTSProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||
self.provider = self._parse_provider(key_value, provider_config)
|
||||
|
||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
||||
if key_value.lower().startswith("other["):
|
||||
try:
|
||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError("无效的other[...]格式,应形如 other[{...}]")
|
||||
json_str = match.group(1).strip()
|
||||
otts_config = json.loads(json_str)
|
||||
required = {"OTTS_SKEY", "OTTS_URL", "OTTS_AUTH_TIME"}
|
||||
if missing := required - otts_config.keys():
|
||||
raise ValueError(f"缺少OTTS参数: {', '.join(missing)}")
|
||||
return OTTSProvider(otts_config)
|
||||
except json.JSONDecodeError as e:
|
||||
error_msg = (
|
||||
f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
|
||||
f"错误详情: {e.msg}\n"
|
||||
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
|
||||
)
|
||||
raise ValueError(error_msg) from e
|
||||
except KeyError as e:
|
||||
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
|
||||
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
|
||||
return AzureNativeProvider(config, self.provider_settings)
|
||||
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
if isinstance(self.provider, OTTSProvider):
|
||||
async with self.provider as provider:
|
||||
return await provider.get_audio(
|
||||
text,
|
||||
{
|
||||
"voice": self.provider_config.get("azure_tts_voice"),
|
||||
"style": self.provider_config.get("azure_tts_style"),
|
||||
"role": self.provider_config.get("azure_tts_role"),
|
||||
"rate": self.provider_config.get("azure_tts_rate"),
|
||||
"volume": self.provider_config.get("azure_tts_volume")
|
||||
}
|
||||
)
|
||||
else:
|
||||
async with self.provider as provider:
|
||||
return await provider.get_audio(text)
|
||||
@@ -74,6 +74,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
@@ -5,6 +6,7 @@ from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -24,7 +26,8 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
voice=self.voice,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
@@ -60,12 +61,14 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
result = ""
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
@@ -227,7 +230,8 @@ class ProviderDify(Provider):
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
path = f"data/temp/{item['filename']}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
"""
|
||||
edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库
|
||||
@@ -40,9 +41,9 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
@@ -6,6 +7,7 @@ from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
@@ -87,7 +89,8 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
|
||||
63
astrbot/core/provider/sources/gemini_embedding_source.py
Normal file
63
astrbot/core/provider/sources/gemini_embedding_source.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"gemini_embedding",
|
||||
"Google Gemini Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
api_key: str = provider_config.get("embedding_api_key")
|
||||
api_base: str = provider_config.get("embedding_api_base", None)
|
||||
timeout: int = int(provider_config.get("timeout", 20))
|
||||
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
if api_base:
|
||||
if api_base.endswith("/"):
|
||||
api_base = api_base[:-1]
|
||||
http_options.base_url = api_base
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
|
||||
self.model = provider_config.get(
|
||||
"embedding_model", "gemini-embedding-exp-03-07"
|
||||
)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 768)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model, contents=text
|
||||
)
|
||||
return result.embeddings[0].values
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model, contents=texts
|
||||
)
|
||||
return [embedding.values for embedding in result.embeddings]
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -3,7 +3,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
@@ -15,7 +15,7 @@ from astrbot import logger
|
||||
from astrbot.api.provider import Personality, Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -65,7 +65,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.api_keys: list = provider_config.get("key", [])
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||
|
||||
@@ -99,7 +99,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
and threshold_str in self.THRESHOLD_MAPPING
|
||||
]
|
||||
|
||||
async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
|
||||
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
|
||||
"""处理API错误,返回是否需要重试"""
|
||||
if e.code == 429 or "API key not valid" in e.message:
|
||||
keys.remove(self.chosen_api_key)
|
||||
@@ -126,7 +126,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
payloads: dict,
|
||||
tools: Optional[FuncCall] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
modalities: Optional[List[str]] = None,
|
||||
modalities: Optional[list[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
@@ -141,52 +141,111 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
|
||||
tool_list = None
|
||||
tool_list = []
|
||||
model_name = self.get_model()
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
|
||||
if native_coderunner:
|
||||
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
|
||||
if native_search:
|
||||
logger.warning("已启用代码执行工具,搜索工具将被忽略")
|
||||
if tools:
|
||||
logger.warning("已启用代码执行工具,函数工具将被忽略")
|
||||
elif native_search:
|
||||
tool_list = [types.Tool(google_search=types.GoogleSearch())]
|
||||
if tools:
|
||||
logger.warning("已启用搜索工具,函数工具将被忽略")
|
||||
if "gemini-2.5" in model_name:
|
||||
if native_coderunner:
|
||||
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
|
||||
if native_search:
|
||||
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
|
||||
if url_context:
|
||||
logger.warning(
|
||||
"代码执行工具与URL上下文工具互斥,已忽略URL上下文工具"
|
||||
)
|
||||
else:
|
||||
if native_search:
|
||||
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
|
||||
if url_context:
|
||||
if hasattr(types, "UrlContext"):
|
||||
tool_list.append(types.Tool(url_context=types.UrlContext()))
|
||||
else:
|
||||
logger.warning(
|
||||
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
|
||||
)
|
||||
|
||||
elif "gemini-2.0-lite" in model_name:
|
||||
if native_coderunner or native_search or url_context:
|
||||
logger.warning(
|
||||
"gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文,将忽略这些设置"
|
||||
)
|
||||
tool_list = None
|
||||
|
||||
else:
|
||||
if native_coderunner:
|
||||
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
|
||||
if native_search:
|
||||
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
|
||||
elif native_search:
|
||||
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
|
||||
if url_context and not native_coderunner:
|
||||
if hasattr(types, "UrlContext"):
|
||||
tool_list.append(types.Tool(url_context=types.UrlContext()))
|
||||
else:
|
||||
logger.warning(
|
||||
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
|
||||
)
|
||||
|
||||
if not tool_list:
|
||||
tool_list = None
|
||||
|
||||
if tools and tool_list:
|
||||
logger.warning("已启用原生工具,函数工具将被忽略")
|
||||
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"),
|
||||
max_output_tokens=payloads.get("max_tokens")
|
||||
or payloads.get("maxOutputTokens"),
|
||||
top_p=payloads.get("top_p") or payloads.get("topP"),
|
||||
top_k=payloads.get("top_k") or payloads.get("topK"),
|
||||
frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"),
|
||||
presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"),
|
||||
frequency_penalty=payloads.get("frequency_penalty")
|
||||
or payloads.get("frequencyPenalty"),
|
||||
presence_penalty=payloads.get("presence_penalty")
|
||||
or payloads.get("presencePenalty"),
|
||||
stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
|
||||
response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"),
|
||||
response_logprobs=payloads.get("response_logprobs")
|
||||
or payloads.get("responseLogprobs"),
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
),
|
||||
)
|
||||
|
||||
def _prepare_conversation(self, payloads: Dict) -> List[types.Content]:
|
||||
def _prepare_conversation(self, payloads: dict) -> list[types.Content]:
|
||||
"""准备 Gemini SDK 的 Content 列表"""
|
||||
|
||||
def create_text_part(text: str) -> types.UserContent:
|
||||
def create_text_part(text: str) -> types.Part:
|
||||
content_a = text if text else " "
|
||||
if not text:
|
||||
logger.warning("文本内容为空,已添加空格占位")
|
||||
return types.UserContent(parts=[types.Part.from_text(text=content_a)])
|
||||
return types.Part.from_text(text=content_a)
|
||||
|
||||
def process_image_url(image_url_dict: dict) -> types.Part:
|
||||
url = image_url_dict["url"]
|
||||
@@ -194,7 +253,17 @@ class ProviderGoogleGenAI(Provider):
|
||||
image_bytes = base64.b64decode(url.split(",", 1)[1])
|
||||
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
|
||||
|
||||
gemini_contents: List[types.Content] = []
|
||||
def append_or_extend(
|
||||
contents: list[types.Content],
|
||||
part: list[types.Part],
|
||||
content_cls: type[types.Content],
|
||||
) -> None:
|
||||
if contents and isinstance(contents[-1], content_cls):
|
||||
contents[-1].parts.extend(part)
|
||||
else:
|
||||
contents.append(content_cls(parts=part))
|
||||
|
||||
gemini_contents: list[types.Content] = []
|
||||
native_tool_enabled = any(
|
||||
[
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
@@ -205,60 +274,53 @@ class ProviderGoogleGenAI(Provider):
|
||||
role, content = message["role"], message.get("content")
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, str):
|
||||
gemini_contents.append(create_text_part(content))
|
||||
elif isinstance(content, list):
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
types.Part.from_text(text=item["text"] or " ")
|
||||
if item["type"] == "text"
|
||||
else process_image_url(item["image_url"])
|
||||
for item in content
|
||||
]
|
||||
gemini_contents.append(types.UserContent(parts=parts))
|
||||
else:
|
||||
parts = [create_text_part(content)]
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
gemini_contents.append(
|
||||
types.ModelContent(parts=[types.Part.from_text(text=content)])
|
||||
)
|
||||
elif "tool_calls" in message and not native_tool_enabled:
|
||||
gemini_contents.extend(
|
||||
[
|
||||
types.ModelContent(
|
||||
parts=[
|
||||
types.Part.from_function_call(
|
||||
name=tool["function"]["name"],
|
||||
args=json.loads(tool["function"]["arguments"]),
|
||||
)
|
||||
]
|
||||
)
|
||||
for tool in message["tool_calls"]
|
||||
]
|
||||
)
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
elif not native_tool_enabled and "tool_calls" in message:
|
||||
parts = [
|
||||
types.Part.from_function_call(
|
||||
name=tool["function"]["name"],
|
||||
args=json.loads(tool["function"]["arguments"]),
|
||||
)
|
||||
for tool in message["tool_calls"]
|
||||
]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
else:
|
||||
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
||||
if native_tool_enabled:
|
||||
if native_tool_enabled and "tool_calls" in message:
|
||||
logger.warning(
|
||||
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文"
|
||||
)
|
||||
gemini_contents.append(
|
||||
types.ModelContent(parts=[types.Part.from_text(text=" ")])
|
||||
)
|
||||
parts = [types.Part.from_text(text=" ")]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
|
||||
elif role == "tool" and not native_tool_enabled:
|
||||
gemini_contents.append(
|
||||
types.UserContent(
|
||||
parts=[
|
||||
types.Part.from_function_response(
|
||||
name=message["tool_call_id"],
|
||||
response={
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
)
|
||||
]
|
||||
parts = [
|
||||
types.Part.from_function_response(
|
||||
name=message["tool_call_id"],
|
||||
response={
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
)
|
||||
)
|
||||
]
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
if gemini_contents and isinstance(gemini_contents[0], types.ModelContent):
|
||||
gemini_contents.pop()
|
||||
|
||||
return gemini_contents
|
||||
|
||||
@@ -271,19 +333,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
||||
|
||||
if finish_reason in {
|
||||
types.FinishReason.PROHIBITED_CONTENT,
|
||||
types.FinishReason.SPII,
|
||||
types.FinishReason.BLOCKLIST,
|
||||
}:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
# 防止旧版本SDK不存在IMAGE_SAFETY
|
||||
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
|
||||
if finish_reason == types.FinishReason.IMAGE_SAFETY:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
if not result_parts:
|
||||
logger.debug(result.candidates)
|
||||
@@ -313,9 +375,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
"""非流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
@@ -327,7 +387,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalities.append("Image")
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
temperature=payloads.get("temperature", 0.7)
|
||||
temperature = payloads.get("temperature", 0.7)
|
||||
|
||||
result: Optional[types.GenerateContentResponse] = None
|
||||
while True:
|
||||
@@ -447,13 +507,15 @@ class ProviderGoogleGenAI(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
@@ -487,13 +549,15 @@ class ProviderGoogleGenAI(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
contexts: str = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
@@ -539,14 +603,14 @@ class ProviderGoogleGenAI(Provider):
|
||||
def get_current_key(self) -> str:
|
||||
return self.chosen_api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
def get_keys(self) -> list[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
async def assemble_context(self, text: str, image_urls: list[str] = None):
|
||||
"""
|
||||
组装上下文。
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import urllib.parse
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -23,7 +25,8 @@ class ProviderGSVITTS(TTSProvider):
|
||||
self.emotion = provider_config.get("emotion")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
params = {"text": text}
|
||||
|
||||
if self.character:
|
||||
|
||||
@@ -60,10 +60,12 @@ class LLMTunerModelLoader(Provider):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = [],
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
149
astrbot/core/provider/sources/minimax_tts_api_source.py
Normal file
149
astrbot/core/provider/sources/minimax_tts_api_source.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from typing import Dict, List, Union, AsyncIterator
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.api import logger
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"minimax_tts_api", "MiniMax TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.api_base: str = provider_config.get(
|
||||
"api_base", "https://api.minimax.chat/v1/t2a_v2"
|
||||
)
|
||||
self.group_id: str = provider_config.get("minimax-group-id", "")
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
|
||||
self.is_timber_weight: bool = provider_config.get(
|
||||
"minimax-is-timber-weight", False
|
||||
)
|
||||
self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads(
|
||||
provider_config.get(
|
||||
"minimax-timber-weight",
|
||||
'[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
|
||||
)
|
||||
)
|
||||
|
||||
self.voice_setting: dict = {
|
||||
"speed": provider_config.get("minimax-voice-speed", 1.0),
|
||||
"vol": provider_config.get("minimax-voice-vol", 1.0),
|
||||
"pitch": provider_config.get("minimax-voice-pitch", 0),
|
||||
"voice_id": ""
|
||||
if self.is_timber_weight
|
||||
else provider_config.get("minimax-voice-id", ""),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "neutral"),
|
||||
"latex_read": provider_config.get("minimax-voice-latex", False),
|
||||
"english_normalization": provider_config.get(
|
||||
"minimax-voice-english-normalization", False
|
||||
),
|
||||
}
|
||||
|
||||
self.audio_setting: dict = {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
}
|
||||
|
||||
self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
def _build_tts_stream_body(self, text: str):
|
||||
"""构建流式请求体"""
|
||||
dict_body: Dict[str, object] = {
|
||||
"model": self.model_name,
|
||||
"text": text,
|
||||
"stream": True,
|
||||
"language_boost": self.lang_boost,
|
||||
"voice_setting": self.voice_setting,
|
||||
"audio_setting": self.audio_setting,
|
||||
}
|
||||
if self.is_timber_weight:
|
||||
dict_body["timber_weights"] = self.timber_weight
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.concat_base_url,
|
||||
headers=self.headers,
|
||||
data=self._build_tts_stream_body(text),
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
buffer = b""
|
||||
while True:
|
||||
chunk = await response.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk
|
||||
|
||||
while b"\n\n" in buffer:
|
||||
try:
|
||||
message, buffer = buffer.split(b"\n\n", 1)
|
||||
if message.startswith(b"data: "):
|
||||
try:
|
||||
data = json.loads(message[6:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Failed to parse JSON data from SSE message"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
buffer = buffer[-1024:]
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
|
||||
|
||||
async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
|
||||
"""解码数据流到 audio 比特流"""
|
||||
chunks = []
|
||||
async for chunk in audio_stream:
|
||||
if chunk.strip():
|
||||
chunks.append(bytes.fromhex(chunk.strip()))
|
||||
return b"".join(chunks)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
|
||||
|
||||
try:
|
||||
# 直接将异步生成器传递给 _audio_play 方法
|
||||
audio_stream = self._call_tts_stream(text)
|
||||
audio = await self._audio_play(audio_stream)
|
||||
|
||||
# 结果保存至文件
|
||||
with open(path, "wb") as file:
|
||||
file.write(audio)
|
||||
|
||||
return path
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise e
|
||||
43
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
43
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from openai import AsyncOpenAI
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"openai_embedding",
|
||||
"OpenAI API Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
base_url=provider_config.get(
|
||||
"embedding_api_base", "https://api.openai.com/v1"
|
||||
),
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -21,7 +21,7 @@ from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List, AsyncGenerator
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -195,7 +195,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
# workaround for #1454
|
||||
if isinstance(tool_call.function.arguments, str):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
else:
|
||||
args = tool_call.function.arguments
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
tool_call_ids.append(tool_call.id)
|
||||
@@ -221,14 +225,16 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
@@ -337,11 +343,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
@@ -362,7 +368,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
@@ -376,6 +382,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
@@ -398,7 +405,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
@@ -428,7 +437,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
@@ -443,6 +452,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
@@ -465,7 +475,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
"""
|
||||
@@ -505,7 +517,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {"role": "user", "content": [{"type": "text", "text": text if text else "[图片]"}]}
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -31,7 +33,8 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name, voice=self.voice, response_format="wav", input=text
|
||||
) as response:
|
||||
|
||||
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
|
||||
@register_provider_adapter(
|
||||
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderVolcengineTTS(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.api_key = provider_config.get("api_key", "")
|
||||
self.appid = provider_config.get("appid", "")
|
||||
self.cluster = provider_config.get("volcengine_cluster", "")
|
||||
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
||||
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
||||
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
|
||||
def _build_request_payload(self, text: str) -> dict:
|
||||
return {
|
||||
"app": {
|
||||
"appid": self.appid,
|
||||
"token": self.api_key,
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4())
|
||||
},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
"encoding": "mp3",
|
||||
"speed_ratio": self.speed_ratio,
|
||||
"volume_ratio": 1.0,
|
||||
"pitch_ratio": 1.0,
|
||||
},
|
||||
"request": {
|
||||
"reqid": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"text_type": "plain",
|
||||
"operation": "query",
|
||||
"with_frontend": 1,
|
||||
"frontend_type": "unitTson"
|
||||
}
|
||||
}
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""异步方法获取语音文件路径"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer; {self.api_key}"
|
||||
}
|
||||
|
||||
payload = self._build_request_payload(text)
|
||||
|
||||
logger.debug(f"请求头: {headers}")
|
||||
logger.debug(f"请求 URL: {self.api_base}")
|
||||
logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.api_base,
|
||||
data=json.dumps(payload),
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
logger.debug(f"响应状态码: {response.status}")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(f"响应内容: {response_text[:200]}...")
|
||||
|
||||
if response.status == 200:
|
||||
resp_data = json.loads(response_text)
|
||||
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data)
|
||||
)
|
||||
|
||||
return file_path
|
||||
else:
|
||||
error_msg = resp_data.get("message", "未知错误")
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
||||
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -50,7 +51,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -61,7 +63,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -53,7 +54,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -64,7 +66,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -31,10 +31,12 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class SimpleOpenAIEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
@@ -1,94 +0,0 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
|
||||
class KnowledgeDBManager:
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(
|
||||
name, cfg["embedding_config"]
|
||||
)
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [
|
||||
f
|
||||
for f in os.listdir(self.db_path)
|
||||
if os.path.isfile(os.path.join(self.db_path, f))
|
||||
]
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
"""
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]["strategy"]:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
class Store:
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
@@ -1,42 +0,0 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path="data/long_term_memory_chroma.db"
|
||||
)
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None),
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding,
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, query: str, top_n=3, metadata_filter: Dict = None
|
||||
) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding, n_results=top_n, where=metadata_filter
|
||||
)
|
||||
return results["documents"][0]
|
||||
@@ -5,6 +5,7 @@
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def load_config(namespace: str) -> Union[dict, bool]:
|
||||
@@ -13,7 +14,7 @@ def load_config(namespace: str) -> Union[dict, bool]:
|
||||
namespace: str, 配置的唯一识别符,也就是配置文件的名字。
|
||||
返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
@@ -43,7 +44,10 @@ def put_config(namespace: str, name: str, key: str, value, description: str):
|
||||
raise ValueError("key 只支持 str 类型。")
|
||||
if not isinstance(value, (str, int, float, bool, list)):
|
||||
raise ValueError("value 只支持 str, int, float, bool, list 类型。")
|
||||
path = f"data/config/{namespace}.json"
|
||||
|
||||
config_dir = os.path.join(get_astrbot_data_path(), "config")
|
||||
path = os.path.join(config_dir, f"{namespace}.json")
|
||||
|
||||
if not os.path.exists(path):
|
||||
with open(path, "w", encoding="utf-8-sig") as f:
|
||||
f.write("{}")
|
||||
@@ -71,7 +75,7 @@ def update_config(namespace: str, key: str, value):
|
||||
key: str, 配置项的键。
|
||||
value: str, int, float, bool, list, 配置项的值。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。")
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -16,7 +17,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
@@ -42,6 +42,8 @@ class Context:
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
@@ -54,14 +56,12 @@ class Context:
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
@@ -126,11 +126,8 @@ class Context:
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
@@ -144,24 +141,46 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_using_tts_provider(self) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
|
||||
def get_using_stt_provider(self) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
@@ -301,3 +320,12 @@ class Context:
|
||||
注册一个异步任务。
|
||||
"""
|
||||
self._register_tasks.append(task)
|
||||
|
||||
def register_web_api(
|
||||
self, route: str, view_handler: Awaitable, methods: list, desc: str
|
||||
):
|
||||
for idx, api in enumerate(self.registered_web_apis):
|
||||
if api[0] == route and methods == api[2]:
|
||||
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
|
||||
return
|
||||
self.registered_web_apis.append((route, view_handler, methods, desc))
|
||||
|
||||
@@ -7,6 +7,9 @@ from astrbot.core.config import AstrBotConfig
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
class GreedyStr(str):
|
||||
"""标记指令完成其他参数接收后的所有剩余文本。"""
|
||||
pass
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter):
|
||||
@@ -68,7 +71,22 @@ class CommandFilter(HandlerFilter):
|
||||
) -> Dict[str, Any]:
|
||||
"""将参数列表 params 根据 param_type 转换为参数字典。"""
|
||||
result = {}
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
param_items = list(param_type.items())
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_items):
|
||||
is_greedy = param_type_or_default_val is GreedyStr
|
||||
|
||||
if is_greedy:
|
||||
# GreedyStr 必须是最后一个参数
|
||||
if i != len(param_items) - 1:
|
||||
raise ValueError(
|
||||
f"参数 '{param_name}' (GreedyStr) 必须是最后一个参数。"
|
||||
)
|
||||
|
||||
# 将剩余的所有部分合并成一个字符串
|
||||
remaining_params = params[i:]
|
||||
result[param_name] = " ".join(remaining_params)
|
||||
break
|
||||
# 没有 GreedyStr 的情况
|
||||
if i >= len(params):
|
||||
if (
|
||||
isinstance(param_type_or_default_val, Type)
|
||||
|
||||
@@ -113,7 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
)
|
||||
raise ValueError(
|
||||
f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||
+ tree
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import heapq
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
@@ -8,100 +7,66 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
"""用于存储所有的 Star Handler"""
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
"""用于快速查找。key 是 handler_full_name"""
|
||||
_handlers = []
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
self._handlers: List[StarHandlerMetadata] = []
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
"""添加一个 Handler"""
|
||||
"""添加一个 Handler,并保持按优先级有序"""
|
||||
if "priority" not in handler.extras_configs:
|
||||
handler.extras_configs["priority"] = 0
|
||||
|
||||
heapq.heappush(self._handlers, (-handler.extras_configs["priority"], handler))
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
self._handlers.append(handler)
|
||||
self._handlers.sort(key=lambda h: -h.extras_configs["priority"])
|
||||
|
||||
def _print_handlers(self):
|
||||
"""打印所有的 Handler"""
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过事件类型获取 Handler
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
only_activated: 是否只返回已激活的插件的处理器
|
||||
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||
|
||||
Returns:
|
||||
List[StarHandlerMetadata]: 处理器列表
|
||||
"""
|
||||
handlers = []
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
|
||||
# 只激活的插件处理器
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
|
||||
# 平台兼容性过滤
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
continue
|
||||
|
||||
handlers.append(handler)
|
||||
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
"""通过 Handler 的全名获取 Handler"""
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过模块名获取 Handler"""
|
||||
return [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
handler for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
def clear(self):
|
||||
"""清空所有的 Handler"""
|
||||
self.star_handlers_map.clear()
|
||||
self._handlers.clear()
|
||||
|
||||
def remove(self, handler: StarHandlerMetadata):
|
||||
"""删除一个 Handler"""
|
||||
# self._handlers.remove(handler)
|
||||
for i, h in enumerate(self._handlers):
|
||||
if h[1] == handler:
|
||||
self._handlers.pop(i)
|
||||
break
|
||||
try:
|
||||
del self.star_handlers_map[handler.handler_full_name]
|
||||
except KeyError:
|
||||
pass
|
||||
self.star_handlers_map.pop(handler.handler_full_name, None)
|
||||
self._handlers = [h for h in self._handlers if h != handler]
|
||||
|
||||
def __iter__(self):
|
||||
"""使 StarHandlerRegistry 支持迭代"""
|
||||
return (handler for _, handler in self._handlers)
|
||||
return iter(self._handlers)
|
||||
|
||||
def __len__(self):
|
||||
"""返回 Handler 的数量"""
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
|
||||
@@ -2,28 +2,46 @@
|
||||
插件的重载、启停、安装、卸载等操作。
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
import asyncio
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
import yaml
|
||||
|
||||
from astrbot.core import logger, pip_installer, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
|
||||
from . import StarMetadata
|
||||
from .context import Context
|
||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||
from .star import star_map, star_registry
|
||||
from .star_handler import star_handlers_registry
|
||||
from .updator import PluginUpdator
|
||||
|
||||
try:
|
||||
from watchfiles import PythonFilter, awatch
|
||||
except ImportError:
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
@@ -34,17 +52,9 @@ class PluginManager:
|
||||
self.context._star_manager = self
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/config"
|
||||
)
|
||||
)
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
@@ -56,6 +66,58 @@ class PluginManager:
|
||||
"""插件配置 Schema 文件名"""
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
asyncio.create_task(self._watch_plugins_changes())
|
||||
|
||||
async def _watch_plugins_changes(self):
|
||||
"""监视插件文件变化"""
|
||||
try:
|
||||
async for changes in awatch(
|
||||
self.plugin_store_path,
|
||||
self.reserved_plugin_path,
|
||||
watch_filter=PythonFilter(),
|
||||
recursive=True,
|
||||
):
|
||||
# 处理文件变化
|
||||
await self._handle_file_changes(changes)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"插件热重载监视任务异常: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _handle_file_changes(self, changes):
|
||||
"""处理文件变化"""
|
||||
logger.info(f"检测到文件变化: {changes}")
|
||||
plugins_to_check = []
|
||||
|
||||
for star in star_registry:
|
||||
if not star.activated:
|
||||
continue
|
||||
if star.root_dir_name is None:
|
||||
continue
|
||||
if star.reserved:
|
||||
plugin_dir_path = os.path.join(
|
||||
self.reserved_plugin_path, star.root_dir_name
|
||||
)
|
||||
else:
|
||||
plugin_dir_path = os.path.join(
|
||||
self.plugin_store_path, star.root_dir_name
|
||||
)
|
||||
plugins_to_check.append((plugin_dir_path, star.name))
|
||||
reloaded_plugins = set()
|
||||
for change in changes:
|
||||
_, file_path = change
|
||||
for plugin_dir_path, plugin_name in plugins_to_check:
|
||||
if (
|
||||
os.path.commonpath([plugin_dir_path])
|
||||
== os.path.commonpath([plugin_dir_path, file_path])
|
||||
and plugin_name not in reloaded_plugins
|
||||
):
|
||||
logger.info(f"检测到插件 {plugin_name} 文件变化,正在重载...")
|
||||
await self.reload(plugin_name)
|
||||
reloaded_plugins.add(plugin_name)
|
||||
break
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
||||
@@ -104,7 +166,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -123,7 +185,7 @@ class PluginManager:
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
pip_installer.install(requirements_path=pth)
|
||||
await pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
@@ -345,7 +407,7 @@ class PluginManager:
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# 尝试安装依赖
|
||||
self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
await self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -389,11 +451,11 @@ class PluginManager:
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
# metadata.config = plugin_config
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -580,16 +642,17 @@ class PluginManager:
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path):
|
||||
if os.path.exists(readme_path) and nh3:
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
plugin_info = {"repo": plugin.repo, "readme": cleaned_content}
|
||||
|
||||
return plugin_info
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -6,7 +8,7 @@ from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map
|
||||
from pathlib import Path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class StarTools:
|
||||
@@ -180,7 +182,7 @@ class StarTools:
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path("data/plugin_data") / plugin_name
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -6,22 +6,20 @@ from ..updator import RepoZipUpdator
|
||||
from astrbot.core.utils.io import remove_dir, on_error
|
||||
from ..star.star import StarMetadata
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path
|
||||
|
||||
|
||||
class PluginUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
async def install(self, repo_url: str, proxy="") -> str:
|
||||
repo_name = self.format_repo_name(repo_url)
|
||||
_, repo_name, _ = self.parse_github_url(repo_url)
|
||||
repo_name = self.format_name(repo_name)
|
||||
plugin_path = os.path.join(self.plugin_store_path, repo_name)
|
||||
await self.download_from_repo_url(plugin_path, repo_url, proxy)
|
||||
self.unzip_file(plugin_path + ".zip", plugin_path)
|
||||
@@ -34,10 +32,6 @@ class PluginUpdator(RepoZipUpdator):
|
||||
if not repo_url:
|
||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
repo_url = f"{proxy}/{repo_url}"
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
@@ -57,7 +51,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
def unzip_file(self, zip_path: str, target_dir: str):
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
update_dir = ""
|
||||
logger.info(f"解压文件: {zip_path}")
|
||||
logger.info(f"正在解压压缩包: {zip_path}")
|
||||
with zipfile.ZipFile(zip_path, "r") as z:
|
||||
update_dir = z.namelist()[0]
|
||||
z.extractall(target_dir)
|
||||
|
||||
@@ -6,6 +6,7 @@ from .zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
|
||||
class AstrBotUpdator(RepoZipUpdator):
|
||||
@@ -16,9 +17,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
|
||||
)
|
||||
self.MAIN_PATH = get_astrbot_path()
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
@@ -45,13 +44,26 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
def _reboot(self, delay: int = 3):
|
||||
"""重启当前程序
|
||||
在指定的延迟后,终止所有子进程并重新启动程序
|
||||
这里只能使用 os.exec* 来重启程序
|
||||
"""
|
||||
py = sys.executable
|
||||
time.sleep(delay)
|
||||
self.terminate_child_processes()
|
||||
py = py.replace(" ", "\\ ")
|
||||
if os.name == "nt":
|
||||
py = f'"{sys.executable}"'
|
||||
else:
|
||||
py = sys.executable
|
||||
|
||||
try:
|
||||
os.execl(py, py, *sys.argv)
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
if os.name == "nt":
|
||||
args = [
|
||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||
]
|
||||
else:
|
||||
args = sys.argv[1:]
|
||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||
else:
|
||||
os.execl(sys.executable, py, *sys.argv)
|
||||
except Exception as e:
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
@@ -67,6 +79,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
file_url = None
|
||||
|
||||
if os.environ.get("ASTRBOT_CLI"):
|
||||
raise Exception("不支持更新CLI启动的AstrBot") # 避免版本管理混乱
|
||||
|
||||
if latest:
|
||||
latest_version = update_data[0]["tag_name"]
|
||||
if self.compare_version(VERSION, latest_version) >= 0:
|
||||
|
||||
41
astrbot/core/utils/astrbot_path.py
Normal file
41
astrbot/core/utils/astrbot_path.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Astrbot统一路径获取
|
||||
|
||||
项目路径:固定为源码所在路径
|
||||
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_astrbot_path() -> str:
|
||||
"""获取Astrbot项目路径"""
|
||||
return os.path.realpath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../")
|
||||
)
|
||||
|
||||
|
||||
def get_astrbot_root() -> str:
|
||||
"""获取Astrbot根目录路径"""
|
||||
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||
return os.path.realpath(path)
|
||||
else:
|
||||
return os.path.realpath(os.getcwd())
|
||||
|
||||
|
||||
def get_astrbot_data_path() -> str:
|
||||
"""获取Astrbot数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
|
||||
|
||||
|
||||
def get_astrbot_config_path() -> str:
|
||||
"""获取Astrbot配置文件路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
@@ -14,6 +14,7 @@ import certifi
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def on_error(func, path, exc_info):
|
||||
@@ -49,11 +50,11 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
|
||||
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
for f in os.listdir(temp_dir):
|
||||
path = os.path.join(temp_dir, f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600 * 12:
|
||||
@@ -63,7 +64,7 @@ def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
p = os.path.join(temp_dir, f"{timestamp}.jpg")
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
@@ -201,28 +202,29 @@ def get_local_ip_addresses():
|
||||
|
||||
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if os.path.exists(dist_dir):
|
||||
version_file = os.path.join(dist_dir, "assets", "version")
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
async def download_dashboard():
|
||||
async def download_dashboard(path: str = None, extract_path: str = "data"):
|
||||
"""下载管理面板文件"""
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = (
|
||||
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
)
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
with zipfile.ZipFile(path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user