Compare commits
451 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c3d98acbe | ||
|
|
7311786f48 | ||
|
|
82de9c926e | ||
|
|
7fd86d4de3 | ||
|
|
724da29e2a | ||
|
|
54113d7b94 | ||
|
|
66396e8290 | ||
|
|
72be76215f | ||
|
|
ace86703a9 | ||
|
|
7b25495463 | ||
|
|
3d4b651c1f | ||
|
|
d305ae064d | ||
|
|
ac4f3d8907 | ||
|
|
af2687771b | ||
|
|
a67b7f909a | ||
|
|
f9c3e4cdb0 | ||
|
|
dc62c1f8d4 | ||
|
|
0441b51a68 | ||
|
|
5c0c9f687e | ||
|
|
e049c54043 | ||
|
|
99e47540d5 | ||
|
|
8e1885ffeb | ||
|
|
8501a0c205 | ||
|
|
797f2a3173 | ||
|
|
1057b4bc35 | ||
|
|
efc0116595 | ||
|
|
cdc560fad0 | ||
|
|
75a2803710 | ||
|
|
fb3169faa4 | ||
|
|
d587bd837e | ||
|
|
b9fab74edc | ||
|
|
50c22bbadb | ||
|
|
d0b10b9195 | ||
|
|
c8fe4f4a3c | ||
|
|
a8ba0720af | ||
|
|
745a01246c | ||
|
|
bee5d3550f | ||
|
|
1789393151 | ||
|
|
345afe1338 | ||
|
|
65428aa49f | ||
|
|
b251ee9322 | ||
|
|
04f00682a0 | ||
|
|
90dcda1475 | ||
|
|
f1ee4eb89f | ||
|
|
343fc22168 | ||
|
|
00ef0d7e3d | ||
|
|
f2deaf6199 | ||
|
|
617a2c010e | ||
|
|
38eae1d1ee | ||
|
|
7e4c89b0cb | ||
|
|
14c29f07bd | ||
|
|
825e3dbcf5 | ||
|
|
8275130f04 | ||
|
|
2c47abea95 | ||
|
|
85aa28d724 | ||
|
|
53a3736b04 | ||
|
|
86ba3c230e | ||
|
|
8d21126bd6 | ||
|
|
74ded91976 | ||
|
|
7c27520d57 | ||
|
|
b54bbc4c5a | ||
|
|
3e09a4ddd4 | ||
|
|
f93f04a536 | ||
|
|
b93f30b809 | ||
|
|
95bd2f26a5 | ||
|
|
7cfcf056f9 | ||
|
|
96b565e1e8 | ||
|
|
9d7ad7a18f | ||
|
|
9838c2758b | ||
|
|
1b1f5f5a5e | ||
|
|
0f95f62aa1 | ||
|
|
9405ba7871 | ||
|
|
60b2ff0a7a | ||
|
|
e6c8507379 | ||
|
|
420db5416e | ||
|
|
6e03218d54 | ||
|
|
5e4bd36b26 | ||
|
|
bbc039366e | ||
|
|
e1ec7dbbba | ||
|
|
075b008740 | ||
|
|
b2c382fa01 | ||
|
|
c5f9b5861f | ||
|
|
2dace4c697 | ||
|
|
c7891385ca | ||
|
|
2059ddcadf | ||
|
|
ba1b68df20 | ||
|
|
403b61836d | ||
|
|
b5af7d1eb9 | ||
|
|
f453af6e4c | ||
|
|
64245d001c | ||
|
|
7d92965cae | ||
|
|
b4fa08c4e2 | ||
|
|
d4e9566851 | ||
|
|
a26b494f7f | ||
|
|
b84e22e41f | ||
|
|
cee6efab19 | ||
|
|
30f71cb550 | ||
|
|
771e755a78 | ||
|
|
16ec462abd | ||
|
|
ca55465d3c | ||
|
|
7098c98dde | ||
|
|
f56355da89 | ||
|
|
422160debd | ||
|
|
8062cf406a | ||
|
|
0e802232ec | ||
|
|
f650a9205d | ||
|
|
c85dbb2347 | ||
|
|
a6a79128c8 | ||
|
|
42839627e8 | ||
|
|
e7f35098e4 | ||
|
|
267e68a894 | ||
|
|
b32b444438 | ||
|
|
522d0f8313 | ||
|
|
5715e5de67 | ||
|
|
cc6b05e8b3 | ||
|
|
417747d5d0 | ||
|
|
a34f439226 | ||
|
|
b7ca014fd0 | ||
|
|
fa098d585a | ||
|
|
c35a14e3ec | ||
|
|
60651736a5 | ||
|
|
581f9b7bd3 | ||
|
|
124eb04807 | ||
|
|
1d561da7fb | ||
|
|
16e3cd0784 | ||
|
|
a6d91933dc | ||
|
|
445c40f758 | ||
|
|
725a841a3b | ||
|
|
f77c453843 | ||
|
|
ba6718d5bc | ||
|
|
cdb7a1b3fa | ||
|
|
a03c79b89d | ||
|
|
98800d3426 | ||
|
|
a616adaac4 | ||
|
|
ffb5605c99 | ||
|
|
621b556856 | ||
|
|
a3ffecbb2a | ||
|
|
ea64cebe2a | ||
|
|
e79487dd5f | ||
|
|
7fe1c1ec89 | ||
|
|
ab2bbff369 | ||
|
|
ec32825309 | ||
|
|
fd0c182087 | ||
|
|
49fcff1daf | ||
|
|
33b64ddf39 | ||
|
|
4c447aa648 | ||
|
|
ccbfc3d274 | ||
|
|
f83fe43bbb | ||
|
|
19022d67f8 | ||
|
|
58a815dd6b | ||
|
|
bc9fe82860 | ||
|
|
b3cd9bf2b9 | ||
|
|
c5c2b829ec | ||
|
|
9713f96401 | ||
|
|
11f35ebf96 | ||
|
|
7d403aa181 | ||
|
|
64af810a4a | ||
|
|
30821905af | ||
|
|
a9dbff756b | ||
|
|
a6aba10d3d | ||
|
|
9c276c37fe | ||
|
|
6ab6c0fd4c | ||
|
|
b6b0fe3fff | ||
|
|
0d5825bda9 | ||
|
|
cdfb64631a | ||
|
|
d161c281c8 | ||
|
|
8fed5bf2a1 | ||
|
|
98d2e9bd27 | ||
|
|
a03af55edd | ||
|
|
86e2fd9aee | ||
|
|
97bd0e5e58 | ||
|
|
ceaba21986 | ||
|
|
172a77d942 | ||
|
|
4f9d2d2a7d | ||
|
|
8c929f6e05 | ||
|
|
3319b71f5b | ||
|
|
46ec028a5b | ||
|
|
0ce0ef3e5c | ||
|
|
375b071cb2 | ||
|
|
29e1417ff2 | ||
|
|
75db2bd366 | ||
|
|
60ca1efbda | ||
|
|
2692e4978b | ||
|
|
91982eb002 | ||
|
|
bb1dec76fa | ||
|
|
f618b8fcdc | ||
|
|
9147cab75b | ||
|
|
5f07bcc8e6 | ||
|
|
705cf2ea1b | ||
|
|
42c4394484 | ||
|
|
221221a3c1 | ||
|
|
9564166297 | ||
|
|
f5cf3c3c8e | ||
|
|
18f919fb6b | ||
|
|
0924835253 | ||
|
|
20d2e5c578 | ||
|
|
907801605c | ||
|
|
93bc684e8c | ||
|
|
a76c98d57e | ||
|
|
d937a800d0 | ||
|
|
d16f3a227f | ||
|
|
80c9a3eeda | ||
|
|
e68173b451 | ||
|
|
40c27d87f5 | ||
|
|
3c13b5049d | ||
|
|
8288d5e51f | ||
|
|
6e1449900a | ||
|
|
4ffbb18ab4 | ||
|
|
b27271b7a3 | ||
|
|
ebb6665f64 | ||
|
|
e4e5731ffd | ||
|
|
2ab5810f13 | ||
|
|
af934c5d09 | ||
|
|
1e0cf7c112 | ||
|
|
46859c93c9 | ||
|
|
ea1f9cb3b2 | ||
|
|
1641549016 | ||
|
|
716a5dbb8a | ||
|
|
af98cb11c5 | ||
|
|
9a4c2cf341 | ||
|
|
2bc3bcd102 | ||
|
|
d6c663f79d | ||
|
|
9ed86e5f53 | ||
|
|
303e0bc037 | ||
|
|
2cc24019f9 | ||
|
|
83ce774d19 | ||
|
|
2b4ee13b5e | ||
|
|
3a964561f0 | ||
|
|
6959f86632 | ||
|
|
537d373e10 | ||
|
|
cceadf222c | ||
|
|
cf5a4af623 | ||
|
|
39aea11c22 | ||
|
|
c2f1227700 | ||
|
|
900f14d37c | ||
|
|
598249b1d6 | ||
|
|
7ed15bdf04 | ||
|
|
2fc0ec0f72 | ||
|
|
5e9c2a669b | ||
|
|
b310521884 | ||
|
|
288945bf7e | ||
|
|
4fc07cff36 | ||
|
|
b884fe0e86 | ||
|
|
855858c236 | ||
|
|
c11a2a5419 | ||
|
|
773a6572af | ||
|
|
88ad373c9b | ||
|
|
51666464b9 | ||
|
|
5af9cf2f52 | ||
|
|
12c4ae4b10 | ||
|
|
4e1bef414a | ||
|
|
e896c18644 | ||
|
|
c852685e74 | ||
|
|
1e99797df8 | ||
|
|
52a4c986a8 | ||
|
|
c501728204 | ||
|
|
6b067fa6a7 | ||
|
|
a1cd5c53a9 | ||
|
|
a46d487e03 | ||
|
|
3deb6d3ab3 | ||
|
|
af34cdd5d2 | ||
|
|
6e1393235a | ||
|
|
343e0b54b9 | ||
|
|
ecb70cb6f7 | ||
|
|
ca50618af6 | ||
|
|
29c07ba83e | ||
|
|
45fbb83a9f | ||
|
|
ae7ba2df25 | ||
|
|
c3ef57cc32 | ||
|
|
7bb4ca5a14 | ||
|
|
063783d81d | ||
|
|
42116c9b65 | ||
|
|
a36e11973d | ||
|
|
5125568ea2 | ||
|
|
0fa164e50d | ||
|
|
cf814e81ee | ||
|
|
43a45f18ce | ||
|
|
ad51381063 | ||
|
|
0b0e4ce904 | ||
|
|
6a3e04d688 | ||
|
|
4107a17370 | ||
|
|
06b4d8f169 | ||
|
|
1c0c820746 | ||
|
|
d061403a28 | ||
|
|
5c092321a6 | ||
|
|
bdd3f61c1f | ||
|
|
8023557d6e | ||
|
|
074b0ced7a | ||
|
|
3864b1ac9b | ||
|
|
6e9b43457d | ||
|
|
ca1aec8920 | ||
|
|
acac580862 | ||
|
|
673e1b2980 | ||
|
|
f62157be72 | ||
|
|
f894ecf3b6 | ||
|
|
66dd4e28ad | ||
|
|
939dc1b0fb | ||
|
|
56bf5d38a1 | ||
|
|
d09b70b295 | ||
|
|
205180387a | ||
|
|
39c8cfeda5 | ||
|
|
f38a329be5 | ||
|
|
a0cd069539 | ||
|
|
bf306a2f01 | ||
|
|
c31f93a8d1 | ||
|
|
4730ab6309 | ||
|
|
1ae78ca98c | ||
|
|
d2379da478 | ||
|
|
0f64981b20 | ||
|
|
0002e49bb5 | ||
|
|
db13a60274 | ||
|
|
db0f11a359 | ||
|
|
ac7f43520b | ||
|
|
f67b9f5f6e | ||
|
|
c75156c4ce | ||
|
|
10270b5595 | ||
|
|
f7458572ed | ||
|
|
d57b7222b2 | ||
|
|
62e70a673a | ||
|
|
5e9eba6478 | ||
|
|
cb02dfe1a4 | ||
|
|
b50739e1af | ||
|
|
8da1b0212d | ||
|
|
ca1f2acb33 | ||
|
|
c15f966669 | ||
|
|
7705b8781a | ||
|
|
b2502746f0 | ||
|
|
ab68094386 | ||
|
|
bbec701223 | ||
|
|
b29d14e600 | ||
|
|
86e51c5cd1 | ||
|
|
cb8267be3f | ||
|
|
eaed43915c | ||
|
|
bd91fd2c38 | ||
|
|
1203b214cd | ||
|
|
c3fec15f11 | ||
|
|
0545653494 | ||
|
|
db2989bdb4 | ||
|
|
587bd00a19 | ||
|
|
960ff438e8 | ||
|
|
98e7ea85d3 | ||
|
|
2549e44710 | ||
|
|
4d32b563ca | ||
|
|
3a4b732977 | ||
|
|
500909a28e | ||
|
|
07753eb25b | ||
|
|
c6eaf3d010 | ||
|
|
6723fe8271 | ||
|
|
3348b70435 | ||
|
|
35a8527c16 | ||
|
|
7afc475290 | ||
|
|
789bceaa3a | ||
|
|
abbc043969 | ||
|
|
654e5762f1 | ||
|
|
507c3e3629 | ||
|
|
991dfeb2f2 | ||
|
|
26482fc2d3 | ||
|
|
e0ce6d9688 | ||
|
|
946595216a | ||
|
|
864b6bc56d | ||
|
|
6ea5b7581f | ||
|
|
f70b8f0c10 | ||
|
|
1593bcb537 | ||
|
|
bf7fc02c8d | ||
|
|
143702b92b | ||
|
|
c5ccc1a084 | ||
|
|
2ecb52a9b2 | ||
|
|
6439917cbe | ||
|
|
d21c18f657 | ||
|
|
25ef0039e4 | ||
|
|
e6981290bc | ||
|
|
75c3d8abbd | ||
|
|
d88683f498 | ||
|
|
40b9aa3a4c | ||
|
|
b6d1515d58 | ||
|
|
e01d4264e3 | ||
|
|
2117b65487 | ||
|
|
a7823b352f | ||
|
|
c543b62a08 | ||
|
|
3923b87f08 | ||
|
|
b7ecdadb83 | ||
|
|
5ff121e1ed | ||
|
|
f486e5448f | ||
|
|
c5aae98558 | ||
|
|
6d8a3b9897 | ||
|
|
6d98780e19 | ||
|
|
3ad2c46f3f | ||
|
|
a730cee7fd | ||
|
|
77c823c100 | ||
|
|
124f21c67a | ||
|
|
e46cf20dd3 | ||
|
|
4bef5e8313 | ||
|
|
22e93b0af4 | ||
|
|
5aeca9662b | ||
|
|
b996cf1f05 | ||
|
|
878a106877 | ||
|
|
45d36f86fd | ||
|
|
b108ae403a | ||
|
|
887ed66768 | ||
|
|
dac840a887 | ||
|
|
238de4ba8c | ||
|
|
9a7bdade43 | ||
|
|
aa84556204 | ||
|
|
6b68069fcd | ||
|
|
42c7034fb2 | ||
|
|
060c7e0145 | ||
|
|
b5b085dfb1 | ||
|
|
fc06ce9d7f | ||
|
|
d8d81b05a7 | ||
|
|
a60f42b1f2 | ||
|
|
6e18be88d0 | ||
|
|
b45e439c48 | ||
|
|
b87061c18c | ||
|
|
f78aca7752 | ||
|
|
3ccca2aa10 | ||
|
|
6d7c40eb76 | ||
|
|
da4cd7fb65 | ||
|
|
c97cda6b84 | ||
|
|
7a7fd4167a | ||
|
|
dffc1a43d5 | ||
|
|
36897fea1e | ||
|
|
c7b34735f0 | ||
|
|
5b07176c88 | ||
|
|
474b40d660 | ||
|
|
a62901b948 | ||
|
|
25d8746327 | ||
|
|
aff1698223 | ||
|
|
7f8941745f | ||
|
|
b858401098 | ||
|
|
d5a158b80f | ||
|
|
f315f284aa | ||
|
|
c367f5009d | ||
|
|
6db1e63bda | ||
|
|
e22ab2ede6 | ||
|
|
b7d7e0b682 | ||
|
|
96bba15f2f | ||
|
|
fcf965a595 | ||
|
|
e1a20d3c22 | ||
|
|
2abd7d8c5d | ||
|
|
5b8f73cdd7 | ||
|
|
7fd765421f | ||
|
|
d9d94af022 | ||
|
|
790b924e57 | ||
|
|
4a62f877df | ||
|
|
ac47c57bb7 | ||
|
|
ccdbb01513 | ||
|
|
5206d750ac | ||
|
|
a800e3df67 | ||
|
|
ccb1f87a20 | ||
|
|
c111da4681 | ||
|
|
f06be6ed21 |
@@ -20,4 +20,5 @@ dashboard/
|
||||
data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
.ruff_cache/
|
||||
.astrbot
|
||||
30
.github/workflows/auto_release.yml
vendored
30
.github/workflows/auto_release.yml
vendored
@@ -23,6 +23,36 @@ jobs:
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo ${{ github.ref_name }} > dist/assets/version
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Upload to Cloudflare R2
|
||||
env:
|
||||
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
|
||||
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
|
||||
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
|
||||
R2_BUCKET_NAME: "astrbot"
|
||||
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
|
||||
VERSION_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
echo "Installing rclone..."
|
||||
curl https://rclone.org/install.sh | sudo bash
|
||||
|
||||
echo "Configuring rclone remote..."
|
||||
mkdir -p ~/.config/rclone
|
||||
cat <<EOF > ~/.config/rclone/rclone.conf
|
||||
[r2]
|
||||
type = s3
|
||||
provider = Cloudflare
|
||||
access_key_id = $R2_ACCESS_KEY_ID
|
||||
secret_access_key = $R2_SECRET_ACCESS_KEY
|
||||
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
|
||||
EOF
|
||||
|
||||
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
|
||||
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
|
||||
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
|
||||
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
|
||||
|
||||
- name: Fetch Changelog
|
||||
run: |
|
||||
|
||||
6
.github/workflows/dashboard_ci.yml
vendored
6
.github/workflows/dashboard_ci.yml
vendored
@@ -1,6 +1,10 @@
|
||||
name: AstrBot Dashboard CI
|
||||
|
||||
on: [push]
|
||||
on:
|
||||
push:
|
||||
branches: [ "master" ]
|
||||
pull_request:
|
||||
branches: [ "master" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
35
.github/workflows/docker-image.yml
vendored
35
.github/workflows/docker-image.yml
vendored
@@ -11,24 +11,42 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: 拉取源码
|
||||
- name: Pull The Codes
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
|
||||
- name: 设置 QEMU
|
||||
- name: Get latest tag (only on manual trigger)
|
||||
id: get-latest-tag
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: |
|
||||
tag=$(git describe --tags --abbrev=0)
|
||||
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout to latest tag (only on manual trigger)
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||
|
||||
- name: Set QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
- name: Set Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: 登录到 DockerHub
|
||||
- name: Log in to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: 构建和推送 Docker hub
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: Soulter
|
||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and Push Docker to DockerHub and Github GHCR
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
@@ -36,8 +54,9 @@ jobs:
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
ghcr.io/soulter/astrbot:latest
|
||||
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -30,3 +30,4 @@ packages/python_interpreter/workplace
|
||||
.conda/
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.10-slim
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
86
README.md
86
README.md
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
@@ -31,13 +31,21 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
> [!NOTE]
|
||||
> [!WARNING]
|
||||
>
|
||||
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,我们正在评估其他方案(如 xxxbot 等)并将在数日内接入(很快!)。目前推荐微信用户暂时使用**微信官方**推出的企业微信接入方式和微信客服接入方式(版本 >= v3.5.7)。详情请前往 [#1443](https://github.com/AstrBotDevs/AstrBot/issues/1443) 讨论。
|
||||
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
<details><summary>1. AstrBot 现已自带知识库能力</summary>
|
||||
|
||||
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
@@ -45,7 +53,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
@@ -78,14 +86,29 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
#### 手动部署
|
||||
|
||||
推荐使用 `uv`。
|
||||
> 推荐使用 `uv`。
|
||||
|
||||
首先,安装 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
通过 Git Clone 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
pip install uv
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者,直接通过 uvx 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
mkdir astrbot && cd astrbot
|
||||
uvx astrbot init
|
||||
# uvx astrbot run
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
@@ -94,40 +117,50 @@ uv run main.py
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| Telegram | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 企业微信 | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 微信客服 | ✔ | 私聊 | 文字、图片 |
|
||||
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| 微信个人号 | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
| 钉钉 | ✔ |
|
||||
| Slack | ✔ |
|
||||
| Discord | ✔ |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| 微信对话开放平台 | 🚧 |
|
||||
| WhatsApp | 🚧 |
|
||||
| 小爱音响 | 🚧 |
|
||||
|
||||
## ⚡ 提供商支持情况
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| DashScope(阿里云百炼应用) | ✔ | LLMOps | |
|
||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -151,7 +184,6 @@ pre-commit install
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
|
||||
1
astrbot/cli/__init__.py
Normal file
1
astrbot/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "3.5.8"
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import click
|
||||
from pathlib import Path
|
||||
from astrbot.core.config.default import VERSION
|
||||
"""
|
||||
AstrBot CLI入口
|
||||
"""
|
||||
|
||||
import click
|
||||
import sys
|
||||
from . import __version__
|
||||
from .commands import init, run, plug, conf
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
@@ -14,210 +14,25 @@ logo_tmpl = r"""
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# utils
|
||||
def _get_astrbot_root(path: str | None) -> Path:
|
||||
"""获取astrbot根目录"""
|
||||
match path:
|
||||
case None:
|
||||
match ASTRBOT_ROOT := os.getenv("ASTRBOT_ROOT"):
|
||||
case None:
|
||||
astrbot_root = Path.cwd() / "data"
|
||||
case _:
|
||||
astrbot_root = Path(ASTRBOT_ROOT).resolve()
|
||||
case str():
|
||||
astrbot_root = Path(path).resolve()
|
||||
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
if not dot_astrbot.exists():
|
||||
if click.confirm(
|
||||
f"运行前必须先执行初始化!请检查当前目录是否正确,回车以继续: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
astrbot_root.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
|
||||
return astrbot_root
|
||||
|
||||
|
||||
# 通过类型来验证先后,必须先获取 Path 对象才能对该目录进行检查
|
||||
def _check_astrbot_root(astrbot_root: Path) -> None:
|
||||
"""验证"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
if not astrbot_root.exists():
|
||||
click.echo(f"AstrBot root directory does not exist: {astrbot_root}")
|
||||
click.echo("Please run 'astrbot init' to create the directory.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo(f"AstrBot root directory exists: {astrbot_root}")
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||
)
|
||||
if click.confirm(
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
else:
|
||||
click.echo(f"Welcome back! AstrBot root directory: {astrbot_root}")
|
||||
|
||||
|
||||
async def _check_dashboard(astrbot_root: Path) -> None:
|
||||
"""检查是否安装了dashboard"""
|
||||
try:
|
||||
from ..core.utils.io import get_dashboard_version, download_dashboard
|
||||
except ImportError:
|
||||
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||
|
||||
try:
|
||||
# 添加 create=True 参数以确保在初始化时不会抛出异常
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
# 确保使用 create=True 参数
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if dashboard_version == f"v{VERSION}":
|
||||
click.echo("无需更新")
|
||||
else:
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
# 确保使用 create=True 参数
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("初始化管理面板目录...")
|
||||
# 初始化模式下,下载到指定位置
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
@click.group(name="astrbot")
|
||||
@click.group()
|
||||
@click.version_option(__version__, prog_name="AstrBot")
|
||||
def cli() -> None:
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(logo_tmpl)
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot version: {VERSION}")
|
||||
click.echo(f"AstrBot CLI version: {__version__}")
|
||||
|
||||
|
||||
# region init
|
||||
@cli.command()
|
||||
@click.option("--path", "-p", help="AstrBot 数据目录")
|
||||
@click.option("--force", "-f", is_flag=True, help="强制初始化")
|
||||
def init(path: str | None, force: bool) -> None:
|
||||
"""Initialize AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = _get_astrbot_root(path)
|
||||
if force:
|
||||
if click.confirm(
|
||||
"强制初始化会删除当前目录下的所有文件,是否继续?",
|
||||
default=False,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在删除当前目录下的所有文件...")
|
||||
shutil.rmtree(astrbot_root, ignore_errors=True)
|
||||
|
||||
_check_astrbot_root(astrbot_root)
|
||||
|
||||
click.echo(f"AstrBot root directory: {astrbot_root}")
|
||||
|
||||
if not astrbot_root.exists():
|
||||
# 创建目录
|
||||
astrbot_root.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"Created directory: {astrbot_root}")
|
||||
else:
|
||||
click.echo(f"Directory already exists: {astrbot_root}")
|
||||
|
||||
config_path: Path = astrbot_root / "config"
|
||||
plugins_path: Path = astrbot_root / "plugins"
|
||||
temp_path: Path = astrbot_root / "temp"
|
||||
config_path.mkdir(parents=True, exist_ok=True)
|
||||
plugins_path.mkdir(parents=True, exist_ok=True)
|
||||
temp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
click.echo(f"Created directories: {config_path}, {plugins_path}, {temp_path}")
|
||||
|
||||
# 检查是否安装了dashboard
|
||||
asyncio.run(_check_dashboard(astrbot_root))
|
||||
|
||||
|
||||
# region run
|
||||
@cli.command()
|
||||
@click.option("--path", "-p", help="AstrBot 数据目录")
|
||||
def run(path: str | None = None) -> None:
|
||||
"""Run AstrBot"""
|
||||
# 解析为绝对路径
|
||||
try:
|
||||
from ..core.log import LogBroker
|
||||
from ..core import db_helper
|
||||
from ..core.initial_loader import InitialLoader
|
||||
except ImportError:
|
||||
from astrbot.core.log import LogBroker
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
astrbot_root = _get_astrbot_root(path)
|
||||
|
||||
_check_astrbot_root(astrbot_root)
|
||||
|
||||
asyncio.run(_check_dashboard(astrbot_root))
|
||||
|
||||
log_broker = LogBroker()
|
||||
db = db_helper
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
try:
|
||||
asyncio.run(core_lifecycle.start())
|
||||
except KeyboardInterrupt:
|
||||
click.echo("接收到退出信号,正在关闭 AstrBot...")
|
||||
except Exception as e:
|
||||
click.echo(f"运行时出现错误: {e}")
|
||||
|
||||
|
||||
# region Basic
|
||||
@cli.command(name="version")
|
||||
def version() -> None:
|
||||
"""Show the version of AstrBot"""
|
||||
click.echo(f"AstrBot version: {VERSION}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.command()
|
||||
@click.argument("command_name", required=False, type=str)
|
||||
def help(command_name: str | None) -> None:
|
||||
"""Show help information for commands
|
||||
"""显示命令的帮助信息
|
||||
|
||||
If COMMAND_NAME is provided, show detailed help for that command.
|
||||
Otherwise, show general help information.
|
||||
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
|
||||
否则,显示通用帮助信息。
|
||||
"""
|
||||
ctx = click.get_current_context()
|
||||
if command_name:
|
||||
@@ -234,5 +49,11 @@ def help(command_name: str | None) -> None:
|
||||
click.echo(cli.get_help(ctx))
|
||||
|
||||
|
||||
cli.add_command(init)
|
||||
cli.add_command(run)
|
||||
cli.add_command(help)
|
||||
cli.add_command(plug)
|
||||
cli.add_command(conf)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
6
astrbot/cli/commands/__init__.py
Normal file
6
astrbot/cli/commands/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .cmd_init import init
|
||||
from .cmd_run import run
|
||||
from .cmd_plug import plug
|
||||
from .cmd_conf import conf
|
||||
|
||||
__all__ = ["init", "run", "plug", "conf"]
|
||||
206
astrbot/cli/commands/cmd_conf.py
Normal file
206
astrbot/cli/commands/cmd_conf.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
import click
|
||||
import hashlib
|
||||
import zoneinfo
|
||||
from typing import Any, Callable
|
||||
from ..utils import get_astrbot_root, check_astrbot_root
|
||||
|
||||
|
||||
def _validate_log_level(value: str) -> str:
|
||||
"""验证日志级别"""
|
||||
value = value.upper()
|
||||
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise click.ClickException(
|
||||
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_port(value: str) -> int:
|
||||
"""验证 Dashboard 端口"""
|
||||
try:
|
||||
port = int(value)
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException("端口必须在 1-65535 范围内")
|
||||
return port
|
||||
except ValueError:
|
||||
raise click.ClickException("端口必须是数字")
|
||||
|
||||
|
||||
def _validate_dashboard_username(value: str) -> str:
|
||||
"""验证 Dashboard 用户名"""
|
||||
if not value:
|
||||
raise click.ClickException("用户名不能为空")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_password(value: str) -> str:
|
||||
"""验证 Dashboard 密码"""
|
||||
if not value:
|
||||
raise click.ClickException("密码不能为空")
|
||||
return hashlib.md5(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def _validate_timezone(value: str) -> str:
|
||||
"""验证时区"""
|
||||
try:
|
||||
zoneinfo.ZoneInfo(value)
|
||||
except Exception:
|
||||
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_callback_api_base(value: str) -> str:
|
||||
"""验证回调接口基址"""
|
||||
if not value.startswith("http://") and not value.startswith("https://"):
|
||||
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
|
||||
return value
|
||||
|
||||
|
||||
# 可通过CLI设置的配置项,配置键到验证器函数的映射
|
||||
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
"timezone": _validate_timezone,
|
||||
"log_level": _validate_log_level,
|
||||
"dashboard.port": _validate_dashboard_port,
|
||||
"dashboard.username": _validate_dashboard_username,
|
||||
"dashboard.password": _validate_dashboard_password,
|
||||
"callback_api_base": _validate_callback_api_base,
|
||||
}
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
"""加载或初始化配置文件"""
|
||||
root = get_astrbot_root()
|
||||
if not check_astrbot_root(root):
|
||||
raise click.ClickException(
|
||||
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
config_path = root / "data" / "cmd_config.json"
|
||||
if not config_path.exists():
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8-sig"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.ClickException(f"配置文件解析失败: {str(e)}")
|
||||
|
||||
|
||||
def _save_config(config: dict[str, Any]) -> None:
|
||||
"""保存配置文件"""
|
||||
config_path = get_astrbot_root() / "data" / "cmd_config.json"
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
|
||||
)
|
||||
|
||||
|
||||
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||
"""设置嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts[:-1]:
|
||||
if part not in obj:
|
||||
obj[part] = {}
|
||||
elif not isinstance(obj[part], dict):
|
||||
raise click.ClickException(
|
||||
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
|
||||
)
|
||||
obj = obj[part]
|
||||
obj[parts[-1]] = value
|
||||
|
||||
|
||||
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
"""获取嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj
|
||||
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf():
|
||||
"""配置管理命令
|
||||
|
||||
支持的配置项:
|
||||
|
||||
- timezone: 时区设置 (例如: Asia/Shanghai)
|
||||
|
||||
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
|
||||
- dashboard.port: Dashboard 端口
|
||||
|
||||
- dashboard.username: Dashboard 用户名
|
||||
|
||||
- dashboard.password: Dashboard 密码
|
||||
|
||||
- callback_api_base: 回调接口基址
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str):
|
||||
"""设置配置项的值"""
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
config = _load_config()
|
||||
|
||||
try:
|
||||
old_value = _get_nested_item(config, key)
|
||||
validated_value = CONFIG_VALIDATORS[key](value)
|
||||
_set_nested_item(config, key, validated_value)
|
||||
_save_config(config)
|
||||
|
||||
click.echo(f"配置已更新: {key}")
|
||||
if key == "dashboard.password":
|
||||
click.echo(" 原值: ********")
|
||||
click.echo(" 新值: ********")
|
||||
else:
|
||||
click.echo(f" 原值: {old_value}")
|
||||
click.echo(f" 新值: {validated_value}")
|
||||
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"设置配置失败: {str(e)}")
|
||||
|
||||
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str = None):
|
||||
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||
config = _load_config()
|
||||
|
||||
if key:
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
try:
|
||||
value = _get_nested_item(config, key)
|
||||
if key == "dashboard.password":
|
||||
value = "********"
|
||||
click.echo(f"{key}: {value}")
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"获取配置失败: {str(e)}")
|
||||
else:
|
||||
click.echo("当前配置:")
|
||||
for key in CONFIG_VALIDATORS.keys():
|
||||
try:
|
||||
value = (
|
||||
"********"
|
||||
if key == "dashboard.password"
|
||||
else _get_nested_item(config, key)
|
||||
)
|
||||
click.echo(f" {key}: {value}")
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
55
astrbot/cli/commands/cmd_init.py
Normal file
55
astrbot/cli/commands/cmd_init.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def initialize_astrbot(astrbot_root) -> None:
|
||||
"""执行 AstrBot 初始化逻辑"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(f"Current Directory: {astrbot_root}")
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||
)
|
||||
if click.confirm(
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
|
||||
paths = {
|
||||
"data": astrbot_root / "data",
|
||||
"config": astrbot_root / "data" / "config",
|
||||
"plugins": astrbot_root / "data" / "plugins",
|
||||
"temp": astrbot_root / "data" / "temp",
|
||||
}
|
||||
|
||||
for name, path in paths.items():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
|
||||
@click.command()
|
||||
def init() -> None:
|
||||
"""初始化 AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = get_astrbot_root()
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"初始化失败: {e!s}")
|
||||
247
astrbot/cli/commands/cmd_plug.py
Normal file
247
astrbot/cli/commands/cmd_plug.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import shutil
|
||||
|
||||
|
||||
from ..utils import (
|
||||
get_git_repo,
|
||||
build_plug_list,
|
||||
manage_plugin,
|
||||
PluginStatus,
|
||||
check_astrbot_root,
|
||||
get_astrbot_root,
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
def plug():
|
||||
"""插件管理"""
|
||||
pass
|
||||
|
||||
|
||||
def _get_data_path() -> Path:
|
||||
base = get_astrbot_root()
|
||||
if not check_astrbot_root(base):
|
||||
raise click.ClickException(
|
||||
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
return (base / "data").resolve()
|
||||
|
||||
|
||||
def display_plugins(plugins, title=None, color=None):
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
|
||||
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
|
||||
click.echo("-" * 85)
|
||||
|
||||
for p in plugins:
|
||||
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
|
||||
click.echo(
|
||||
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
|
||||
f"{p['author']:<15} {desc:<30}"
|
||||
)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def new(name: str):
|
||||
"""创建新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(f"插件 {name} 已存在")
|
||||
|
||||
author = click.prompt("请输入插件作者", type=str)
|
||||
desc = click.prompt("请输入插件描述", type=str)
|
||||
version = click.prompt("请输入插件版本", type=str)
|
||||
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
|
||||
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
|
||||
repo = click.prompt("请输入插件仓库:", type=str)
|
||||
if not repo.startswith("http"):
|
||||
raise click.ClickException("仓库地址必须以 http 开头")
|
||||
|
||||
click.echo("下载插件模板...")
|
||||
get_git_repo(
|
||||
"https://github.com/Soulter/helloworld",
|
||||
plug_path,
|
||||
)
|
||||
|
||||
click.echo("重写插件信息...")
|
||||
# 重写 metadata.yaml
|
||||
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"name: {name}\n"
|
||||
f"desc: {desc}\n"
|
||||
f"version: {version}\n"
|
||||
f"author: {author}\n"
|
||||
f"repo: {repo}\n"
|
||||
)
|
||||
|
||||
# 重写 README.md
|
||||
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
|
||||
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
|
||||
|
||||
# 重写 main.py
|
||||
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
new_content = content.replace(
|
||||
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
|
||||
f'@register("{name}", "{author}", "{desc}", "{version}")',
|
||||
)
|
||||
|
||||
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
click.echo(f"插件 {name} 创建成功")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||
def list(all: bool):
|
||||
"""列出插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
# 未发布的插件
|
||||
not_published_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
|
||||
]
|
||||
if not_published_plugins:
|
||||
display_plugins(not_published_plugins, "未发布的插件", "red")
|
||||
|
||||
# 需要更新的插件
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
if need_update_plugins:
|
||||
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
|
||||
|
||||
# 已安装的插件
|
||||
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
|
||||
if installed_plugins:
|
||||
display_plugins(installed_plugins, "已安装的插件", "green")
|
||||
|
||||
# 未安装的插件
|
||||
not_installed_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
|
||||
]
|
||||
if not_installed_plugins and all:
|
||||
display_plugins(not_installed_plugins, "未安装的插件", "blue")
|
||||
|
||||
if (
|
||||
not any([not_published_plugins, need_update_plugins, installed_plugins])
|
||||
and not all
|
||||
):
|
||||
click.echo("未安装任何插件")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
@click.option("--proxy", help="代理服务器地址")
|
||||
def install(name: str, proxy: str | None):
|
||||
"""安装插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def remove(name: str):
|
||||
"""卸载插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||
|
||||
if not plugin or not plugin.get("local_path"):
|
||||
raise click.ClickException(f"插件 {name} 不存在或未安装")
|
||||
|
||||
plugin_path = plugin["local_path"]
|
||||
|
||||
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
|
||||
|
||||
try:
|
||||
shutil.rmtree(plugin_path)
|
||||
click.echo(f"插件 {name} 已卸载")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--proxy", help="Github代理地址")
|
||||
def update(name: str, proxy: str | None):
|
||||
"""更新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
if name:
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
else:
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
|
||||
if not need_update_plugins:
|
||||
click.echo("没有需要更新的插件")
|
||||
return
|
||||
|
||||
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
|
||||
for plugin in need_update_plugins:
|
||||
plugin_name = plugin["name"]
|
||||
click.echo(f"正在更新插件 {plugin_name}...")
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("query")
|
||||
def search(query: str):
|
||||
"""搜索插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
matched_plugins = [
|
||||
p
|
||||
for p in plugins
|
||||
if query.lower() in p["name"].lower()
|
||||
or query.lower() in p["desc"].lower()
|
||||
or query.lower() in p["author"].lower()
|
||||
]
|
||||
|
||||
if not matched_plugins:
|
||||
click.echo(f"未找到匹配 '{query}' 的插件")
|
||||
return
|
||||
|
||||
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
|
||||
63
astrbot/cli/commands/cmd_run.py
Normal file
63
astrbot/cli/commands/cmd_run.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import logger, LogManager, LogBroker, db_helper
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
db = db_helper
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
|
||||
await core_lifecycle.start()
|
||||
|
||||
|
||||
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
|
||||
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
|
||||
@click.command()
|
||||
def run(reload: bool, port: str) -> None:
|
||||
"""运行 AstrBot"""
|
||||
try:
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
astrbot_root = get_astrbot_root()
|
||||
|
||||
if not check_astrbot_root(astrbot_root):
|
||||
raise click.ClickException(
|
||||
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||
sys.path.insert(0, str(astrbot_root))
|
||||
|
||||
if port:
|
||||
os.environ["DASHBOARD_PORT"] = port
|
||||
|
||||
if reload:
|
||||
click.echo("启用插件自动重载")
|
||||
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
|
||||
18
astrbot/cli/utils/__init__.py
Normal file
18
astrbot/cli/utils/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from .basic import (
|
||||
get_astrbot_root,
|
||||
check_astrbot_root,
|
||||
check_dashboard,
|
||||
)
|
||||
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
__all__ = [
|
||||
"get_astrbot_root",
|
||||
"check_astrbot_root",
|
||||
"check_dashboard",
|
||||
"get_git_repo",
|
||||
"manage_plugin",
|
||||
"build_plug_list",
|
||||
"VersionComparator",
|
||||
"PluginStatus",
|
||||
]
|
||||
67
astrbot/cli/utils/basic.py
Normal file
67
astrbot/cli/utils/basic.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""检查路径是否为 AstrBot 根目录"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
if not (path / ".astrbot").exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""获取Astrbot根目录路径"""
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
"""检查是否安装了dashboard"""
|
||||
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||
from astrbot.core.config.default import VERSION
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||
click.echo("管理面板已是最新版本")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("初始化管理面板目录...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
230
astrbot/cli/utils/plugin.py
Normal file
230
astrbot/cli/utils/plugin.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import click
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class PluginStatus(str, Enum):
|
||||
INSTALLED = "已安装"
|
||||
NEED_UPDATE = "需更新"
|
||||
NOT_INSTALLED = "未安装"
|
||||
NOT_PUBLISHED = "未发布"
|
||||
|
||||
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
# 解析仓库信息
|
||||
repo_namespace = url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
# 尝试获取最新的 release
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
try:
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(release_url)
|
||||
resp.raise_for_status()
|
||||
releases = resp.json()
|
||||
|
||||
if releases:
|
||||
# 使用最新的 release
|
||||
download_url = releases[0]["zipball_url"]
|
||||
else:
|
||||
# 没有 release,使用默认分支
|
||||
click.echo(f"正在从默认分支下载 {author}/{repo}")
|
||||
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
except Exception as e:
|
||||
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
|
||||
download_url = url
|
||||
|
||||
# 应用代理
|
||||
if proxy:
|
||||
download_url = f"{proxy}/{download_url}"
|
||||
|
||||
# 下载并解压
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(download_url)
|
||||
if (
|
||||
resp.status_code == 404
|
||||
and "archive/refs/heads/master.zip" in download_url
|
||||
):
|
||||
alt_url = download_url.replace("master.zip", "main.zip")
|
||||
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||
resp = client.get(alt_url)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
zip_content = BytesIO(resp.content)
|
||||
with ZipFile(zip_content) as z:
|
||||
z.extractall(temp_dir)
|
||||
namelist = z.namelist()
|
||||
root_dir = Path(namelist[0]).parts[0] if namelist else ""
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.move(temp_dir / root_dir, target_path)
|
||||
finally:
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
"""从 metadata.yaml 文件加载插件元数据
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含元数据的字典,如果读取失败则返回空字典
|
||||
"""
|
||||
yaml_path = plugin_dir / "metadata.yaml"
|
||||
if yaml_path.exists():
|
||||
try:
|
||||
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
|
||||
except Exception as e:
|
||||
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
|
||||
return {}
|
||||
|
||||
|
||||
def build_plug_list(plugins_dir: Path) -> list:
|
||||
"""构建插件列表,包含本地和在线插件信息
|
||||
|
||||
Args:
|
||||
plugins_dir (Path): 插件目录路径
|
||||
|
||||
Returns:
|
||||
list: 包含插件信息的字典列表
|
||||
"""
|
||||
# 获取本地插件信息
|
||||
result = []
|
||||
if plugins_dir.exists():
|
||||
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||
plugin_dir = plugins_dir / plugin_name
|
||||
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
try:
|
||||
with httpx.Client() as client:
|
||||
resp = client.get("https://api.soulter.top/astrbot/plugins")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
# 与在线插件比对,更新状态
|
||||
online_plugin_names = {plugin["name"] for plugin in online_plugins}
|
||||
for local_plugin in result:
|
||||
if local_plugin["name"] in online_plugin_names:
|
||||
# 查找对应的在线插件
|
||||
online_plugin = next(
|
||||
p for p in online_plugins if p["name"] == local_plugin["name"]
|
||||
)
|
||||
if (
|
||||
VersionComparator.compare_version(
|
||||
local_plugin["version"], online_plugin["version"]
|
||||
)
|
||||
< 0
|
||||
):
|
||||
local_plugin["status"] = PluginStatus.NEED_UPDATE
|
||||
else:
|
||||
# 本地插件未在线上发布
|
||||
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
|
||||
|
||||
# 添加未安装的在线插件
|
||||
for online_plugin in online_plugins:
|
||||
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
|
||||
result.append(online_plugin)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def manage_plugin(
|
||||
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
|
||||
) -> None:
|
||||
"""安装或更新插件
|
||||
|
||||
Args:
|
||||
plugin (dict): 插件信息字典
|
||||
plugins_dir (Path): 插件目录
|
||||
is_update (bool, optional): 是否为更新操作. 默认为 False
|
||||
proxy (str, optional): 代理服务器地址
|
||||
"""
|
||||
plugin_name = plugin["name"]
|
||||
repo_url = plugin["repo"]
|
||||
|
||||
# 如果是更新且有本地路径,直接使用本地路径
|
||||
if is_update and plugin.get("local_path"):
|
||||
target_path = Path(plugin["local_path"])
|
||||
else:
|
||||
target_path = plugins_dir / plugin_name
|
||||
|
||||
backup_path = Path(f"{target_path}_backup") if is_update else None
|
||||
|
||||
# 检查插件是否存在
|
||||
if is_update and not target_path.exists():
|
||||
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
||||
|
||||
# 备份现有插件
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
if is_update:
|
||||
shutil.copytree(target_path, backup_path)
|
||||
|
||||
try:
|
||||
click.echo(
|
||||
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
|
||||
)
|
||||
get_git_repo(repo_url, target_path, proxy)
|
||||
|
||||
# 更新成功,删除备份
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
||||
except Exception as e:
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
if is_update and backup_path.exists():
|
||||
shutil.move(backup_path, target_path)
|
||||
raise click.ClickException(
|
||||
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
||||
)
|
||||
92
astrbot/cli/utils/version_comparator.py
Normal file
92
astrbot/cli/utils/version_comparator.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
拷贝自 astrbot.core.utils.version_comparator
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
@staticmethod
|
||||
def compare_version(v1: str, v2: str) -> int:
|
||||
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||
|
||||
参考: https://semver.org/lang/zh-CN/
|
||||
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.lower().replace("v", "")
|
||||
v2 = v2.lower().replace("v", "")
|
||||
|
||||
def split_version(version):
|
||||
match = re.match(
|
||||
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
|
||||
version,
|
||||
)
|
||||
if not match:
|
||||
return [], None
|
||||
major_minor_patch = match.group(1).split(".")
|
||||
prerelease = match.group(2)
|
||||
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||
parts = [int(x) for x in major_minor_patch]
|
||||
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||
return parts, prerelease
|
||||
|
||||
v1_parts, v1_prerelease = split_version(v1)
|
||||
v2_parts, v2_prerelease = split_version(v2)
|
||||
|
||||
# 比较数字部分
|
||||
length = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||
|
||||
for i in range(length):
|
||||
if v1_parts[i] > v2_parts[i]:
|
||||
return 1
|
||||
elif v1_parts[i] < v2_parts[i]:
|
||||
return -1
|
||||
|
||||
# 比较预发布标签
|
||||
if v1_prerelease is None and v2_prerelease is not None:
|
||||
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is None:
|
||||
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is not None:
|
||||
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||
for i in range(len_pre):
|
||||
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
||||
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
|
||||
|
||||
if p1 is None and p2 is not None:
|
||||
return -1
|
||||
elif p1 is not None and p2 is None:
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, str):
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
return 0 # 预发布标签完全相同
|
||||
|
||||
return 0 # 数字部分和预发布标签都相同
|
||||
|
||||
@staticmethod
|
||||
def _split_prerelease(prerelease):
|
||||
if not prerelease:
|
||||
return None
|
||||
parts = prerelease.split(".")
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
result.append(int(part))
|
||||
else:
|
||||
result.append(part)
|
||||
return result
|
||||
@@ -7,27 +7,28 @@ from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.file_token_service import FileTokenService
|
||||
from .utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs("data", exist_ok=True)
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = (
|
||||
SharedPreferences()
|
||||
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences()
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
astrbot_config.get("pip_install_arg", ""),
|
||||
astrbot_config.get("pypi_index_url", None),
|
||||
)
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -42,11 +43,10 @@ class AstrBotConfig(dict):
|
||||
"""不存在时载入默认配置"""
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
|
||||
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith("/ufeff"): # remove BOM
|
||||
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
@@ -83,23 +83,61 @@ class AstrBotConfig(dict):
|
||||
return conf
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
"""检查配置完整性,如果有新的配置项则返回 True"""
|
||||
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
|
||||
has_new = False
|
||||
|
||||
# 创建一个新的有序字典以保持参考配置的顺序
|
||||
new_conf = {}
|
||||
|
||||
# 先按照参考配置的顺序添加配置项
|
||||
for key, value in refer_conf.items():
|
||||
if key not in conf:
|
||||
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}")
|
||||
# 配置项不存在,插入默认值
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
|
||||
conf[key] = value
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
if conf[key] is None:
|
||||
conf[key] = value
|
||||
# 配置项为 None,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
elif isinstance(value, dict):
|
||||
has_new |= self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
# 递归检查子配置项
|
||||
if not isinstance(conf[key], dict):
|
||||
# 类型不匹配,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
# 递归检查并同步顺序
|
||||
child_has_new = self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
new_conf[key] = conf[key]
|
||||
has_new |= child_has_new
|
||||
else:
|
||||
# 直接使用现有配置
|
||||
new_conf[key] = conf[key]
|
||||
|
||||
# 检查是否存在参考配置中没有的配置项
|
||||
for key in list(conf.keys()):
|
||||
if key not in refer_conf:
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
|
||||
has_new = True
|
||||
|
||||
# 顺序不一致也算作变更
|
||||
if list(conf.keys()) != list(new_conf.keys()):
|
||||
if path:
|
||||
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
|
||||
else:
|
||||
logger.info("检查到配置项顺序不一致,已重新排序")
|
||||
has_new = True
|
||||
|
||||
# 更新原始配置
|
||||
conf.clear()
|
||||
conf.update(new_conf)
|
||||
|
||||
return has_new
|
||||
|
||||
def save_config(self, replace_config: Dict = None):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
|
||||
工作流程:
|
||||
@@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
@@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map
|
||||
class AstrBotCoreLifecycle:
|
||||
"""
|
||||
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
EventBus 等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
"""
|
||||
@@ -54,7 +53,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
|
||||
# 初始化日志代理
|
||||
@@ -73,9 +72,6 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
@@ -87,7 +83,6 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
@@ -11,7 +11,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
|
||||
with open(
|
||||
os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8"
|
||||
) as f:
|
||||
sql = f.read()
|
||||
|
||||
# 初始化数据库
|
||||
|
||||
46
astrbot/core/db/vec_db/base.py
Normal file
46
astrbot/core/db/vec_db/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
similarity: float
|
||||
data: dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化向量数据库
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
top_k (int): 返回的最相似文档的数量
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, doc_id: str) -> bool:
|
||||
"""
|
||||
删除指定文档。
|
||||
Args:
|
||||
doc_id (str): 要删除的文档 ID
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
)
|
||||
await self.connection.commit()
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
Args:
|
||||
row (tuple): The row to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
"doc_id": row[1],
|
||||
"text": row[2],
|
||||
"metadata": row[3],
|
||||
"created_at": row[4],
|
||||
"updated_at": row[5],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
try:
|
||||
import faiss
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
||||
)
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
if path and os.path.exists(path):
|
||||
self.index = faiss.read_index(path)
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 要插入的向量
|
||||
id (int): 向量的ID
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
"""搜索最相似的向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 查询向量
|
||||
k (int): 返回的最相似向量的数量
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
Args:
|
||||
path (str): 保存索引的路径
|
||||
"""
|
||||
faiss.write_index(self.index, self.path)
|
||||
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
doc_id TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
|
||||
|
||||
CREATE INDEX idx_documents_user_id ON documents(user_id);
|
||||
CREATE INDEX idx_documents_group_id ON documents(group_id);
|
||||
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
import json
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
A class to represent a vector database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
self.embedding_provider = embedding_provider
|
||||
self.document_storage = DocumentStorage(doc_store_path)
|
||||
self.embedding_storage = EmbeddingStorage(
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
embedding = await self.embedding_provider.get_embedding(query)
|
||||
scores, indices = await self.embedding_storage.search(
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
# TODO: rerank
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||
# NOTE: maybe the size is less than k.
|
||||
fetched_docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters or {}, ids=indices[0]
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
pos = idx_pos.get(indice_idx)
|
||||
if pos is None:
|
||||
continue
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
return result_docs[:k]
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
删除一条文档
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
68
astrbot/core/file_token_service.py
Normal file
68
astrbot/core/file_token_service.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
||||
"""向令牌服务注册一个文件。
|
||||
|
||||
Args:
|
||||
file_path(str): 文件路径
|
||||
timeout(float): 超时时间,单位秒(可选)
|
||||
|
||||
Returns:
|
||||
str: 一个单次令牌
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 当路径不存在时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
||||
self.staged_files[file_token] = (file_path, expire_time)
|
||||
return file_token
|
||||
|
||||
async def handle_file(self, file_token: str) -> str:
|
||||
"""根据令牌获取文件路径,使用后令牌失效。
|
||||
|
||||
Args:
|
||||
file_token(str): 注册时返回的令牌
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
|
||||
Raises:
|
||||
KeyError: 当令牌不存在或已过期时抛出
|
||||
FileNotFoundError: 当文件本身已被删除时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if file_token not in self.staged_files:
|
||||
raise KeyError(f"无效或过期的文件 token: {file_token}")
|
||||
|
||||
file_path, _ = self.staged_files.pop(file_token)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
return file_path
|
||||
@@ -26,13 +26,14 @@ class InitialLoader:
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||
return
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
|
||||
@@ -22,16 +22,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import typing as T
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from pydantic.v1 import BaseModel
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file
|
||||
|
||||
from astrbot.core import astrbot_config, file_token_service, logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
@@ -99,6 +102,10 @@ class BaseMessageComponent(BaseModel):
|
||||
data[k] = v
|
||||
return {"type": self.type.lower(), "data": data}
|
||||
|
||||
async def to_dict(self) -> dict:
|
||||
# 默认情况下,回退到旧的同步 toDict()
|
||||
return self.toDict()
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
@@ -115,6 +122,9 @@ class Plain(BaseMessageComponent):
|
||||
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||
)
|
||||
|
||||
def toDict(self):
|
||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
@@ -167,7 +177,8 @@ class Record(BaseMessageComponent):
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
@@ -198,6 +209,29 @@ class Record(BaseMessageComponent):
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将语音注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Video(BaseMessageComponent):
|
||||
type: ComponentType = "Video"
|
||||
@@ -208,9 +242,6 @@ class Video(BaseMessageComponent):
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "c" and _[k] not in [2, 3]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -223,6 +254,70 @@ class Video(BaseMessageComponent):
|
||||
return Video(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
async def convert_to_file_path(self) -> str:
|
||||
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。
|
||||
|
||||
Returns:
|
||||
str: 视频的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.file
|
||||
if url and url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url and url.startswith("http"):
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(url, video_file_path)
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
else:
|
||||
raise Exception(f"download failed: {url}")
|
||||
elif os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将视频注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = self.file
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated video file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
@@ -232,6 +327,12 @@ class At(BaseMessageComponent):
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
def toDict(self):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {"qq": str(self.qq)},
|
||||
}
|
||||
|
||||
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
@@ -371,7 +472,8 @@ class Image(BaseMessageComponent):
|
||||
elif url and url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
@@ -403,6 +505,29 @@ class Image(BaseMessageComponent):
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将图片注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
@@ -462,28 +587,48 @@ class Node(BaseMessageComponent):
|
||||
type: ComponentType = "Node"
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[int] = 0 # qq号
|
||||
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||
uin: T.Optional[str] = "0" # qq号
|
||||
content: T.Optional[list[BaseMessageComponent]] = []
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0 # 忽略
|
||||
|
||||
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||
if isinstance(content, list):
|
||||
_content = None
|
||||
if all(isinstance(item, Node) for item in content):
|
||||
_content = [node.toDict() for node in content]
|
||||
else:
|
||||
_content = ""
|
||||
for chain in content:
|
||||
_content += chain.toString()
|
||||
content = _content
|
||||
elif isinstance(content, Node):
|
||||
content = content.toDict()
|
||||
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||
if isinstance(content, Node):
|
||||
# back
|
||||
content = [content]
|
||||
super().__init__(content=content, **_)
|
||||
|
||||
def toString(self):
|
||||
# logger.warn("Protocol: node doesn't support stringify")
|
||||
return ""
|
||||
async def to_dict(self):
|
||||
data_content = []
|
||||
for comp in self.content:
|
||||
if isinstance(comp, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await comp.convert_to_base64()
|
||||
data_content.append(
|
||||
{
|
||||
"type": comp.type.lower(),
|
||||
"data": {"file": f"base64://{bs64}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, (Node, Nodes)):
|
||||
# For Node segments, we recursively convert them to dict
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
else:
|
||||
d = comp.toDict()
|
||||
data_content.append(d)
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": str(self.uin),
|
||||
"nickname": self.name,
|
||||
"content": data_content,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Nodes(BaseMessageComponent):
|
||||
@@ -494,7 +639,22 @@ class Nodes(BaseMessageComponent):
|
||||
super().__init__(nodes=nodes, **_)
|
||||
|
||||
def toDict(self):
|
||||
return {"messages": [node.toDict() for node in self.nodes]}
|
||||
"""Deprecated. Use to_dict instead"""
|
||||
ret = {
|
||||
"messages": [],
|
||||
}
|
||||
for node in self.nodes:
|
||||
d = node.toDict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
d = await node.to_dict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
|
||||
class Xml(BaseMessageComponent):
|
||||
@@ -559,12 +719,12 @@ class File(BaseMessageComponent):
|
||||
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
_file: T.Optional[str] = "" # 本地路径
|
||||
file_: T.Optional[str] = "" # 本地路径
|
||||
url: T.Optional[str] = "" # url
|
||||
_downloaded: bool = False # 是否已经下载
|
||||
|
||||
def __init__(self, name: str = "", file: str = "", url: str = ""):
|
||||
super().__init__(name=name, _file=file, url=url)
|
||||
def __init__(self, name: str, file: str = "", url: str = ""):
|
||||
"""文件消息段。"""
|
||||
super().__init__(name=name, file_=file, url=url)
|
||||
|
||||
@property
|
||||
def file(self) -> str:
|
||||
@@ -574,23 +734,27 @@ class File(BaseMessageComponent):
|
||||
Returns:
|
||||
str: 文件路径
|
||||
"""
|
||||
if self._file and os.path.exists(self._file):
|
||||
return self._file
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url and not self._downloaded:
|
||||
if self.url:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
"不可以在异步上下文中同步等待下载! 请使用 await get_file() 代替"
|
||||
(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段"
|
||||
)
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
|
||||
if self._file and os.path.exists(self._file):
|
||||
return self._file
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
@@ -607,38 +771,79 @@ class File(BaseMessageComponent):
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
self.url = value
|
||||
else:
|
||||
self._file = value
|
||||
self.file_ = value
|
||||
|
||||
async def get_file(self) -> str:
|
||||
"""
|
||||
异步获取文件
|
||||
To 插件开发者: 请注意在使用后清理下载的文件, 以免占用过多空间
|
||||
async def get_file(self, allow_return_url: bool = False) -> str:
|
||||
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
|
||||
|
||||
Args:
|
||||
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
|
||||
注意,如果为 True,也可能返回文件路径。
|
||||
Returns:
|
||||
str: 文件路径
|
||||
str: 文件路径或者 http 下载链接
|
||||
"""
|
||||
if self._file and os.path.exists(self._file):
|
||||
return self._file
|
||||
if allow_return_url and self.url:
|
||||
return self.url
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return self._file
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if self._downloaded:
|
||||
return
|
||||
|
||||
os.makedirs("data/download", exist_ok=True)
|
||||
filename = self.name or f"{uuid.uuid4().hex}"
|
||||
file_path = f"data/download/{filename}"
|
||||
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
self._file = file_path
|
||||
self._downloaded = True
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将文件注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.get_file()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = await self.get_file(allow_return_url=True)
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {
|
||||
"name": self.name,
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
|
||||
@@ -43,31 +43,31 @@ class PreProcessStage(Stage):
|
||||
# STT
|
||||
if self.stt_settings.get("enable", False):
|
||||
# TODO: 独立
|
||||
stt_provider = (
|
||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
)
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
|
||||
@@ -33,6 +33,7 @@ from mcp.types import (
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -67,7 +68,11 @@ class LLMRequestSubStage(Stage):
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
umo = event.unified_msg_origin
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
@@ -283,7 +288,66 @@ class LLMRequestSubStage(Stage):
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
yield
|
||||
|
||||
if event.get_platform_name() == "webchat":
|
||||
# 异步处理 WebChat 特殊情况
|
||||
asyncio.create_task(self._handle_webchat(event, req))
|
||||
|
||||
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
if conversation and not req.conversation.title:
|
||||
messages = json.loads(conversation.history)
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
# if len(latest_pair) > 1:
|
||||
# cleaned_text += (
|
||||
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
|
||||
# )
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await provider.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
prompt=(
|
||||
f"Please summarize the following query of user:\n"
|
||||
f"{cleaned_text}\n"
|
||||
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
|
||||
"You must use the same language as the user."
|
||||
"If you think the dialog is too short to summarize, only output a special mark: `None`"
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
logger.debug(
|
||||
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
|
||||
)
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "None" == title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
event.unified_msg_origin, title=title
|
||||
)
|
||||
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
|
||||
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
|
||||
# TODO: 优化 WebChat 适配器的对话管理
|
||||
if event.session_id:
|
||||
username, cid = event.session_id.split("!")[1:3]
|
||||
db_helper = self.ctx.plugin_manager.context._db
|
||||
db_helper.update_conversation_title(
|
||||
user_id=username,
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
web_chat_back_queue.put_nowait(
|
||||
{
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_llm_response(
|
||||
self,
|
||||
|
||||
@@ -58,33 +58,30 @@ class RateLimitStage(Stage):
|
||||
now = datetime.now()
|
||||
|
||||
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
# 检查并处理限流,可能需要多次检查直到满足条件
|
||||
while True:
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
if len(timestamps) < self.rate_limit_count:
|
||||
timestamps.append(now)
|
||||
break
|
||||
else:
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds() + 0.3
|
||||
|
||||
if len(timestamps) >= self.rate_limit_count:
|
||||
# 达到限流阈值,计算下一个窗口的时间
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(
|
||||
timestamps, now + timedelta(seconds=stall_duration)
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
now = datetime.now()
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
def _remove_expired_timestamps(
|
||||
self, timestamps: Deque[datetime], now: datetime
|
||||
|
||||
@@ -26,33 +26,13 @@ class RespondStage(Stage):
|
||||
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||
Comp.AtAll: lambda comp: True, # @所有人
|
||||
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||
Comp.Music: lambda comp: bool(comp._type)
|
||||
and bool(comp.url)
|
||||
and bool(comp.audio), # 音乐
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||
Comp.Node: lambda comp: bool(comp.name)
|
||||
and comp.uin != 0
|
||||
and bool(comp.content), # 一个转发节点
|
||||
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
@@ -129,8 +109,6 @@ class RespondStage(Stage):
|
||||
if comp_type in self._component_validators:
|
||||
if self._component_validators[comp_type](comp):
|
||||
return False
|
||||
else:
|
||||
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
@@ -175,6 +153,11 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||
non_record_comps = [
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
]
|
||||
|
||||
if self.enable_seg and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
@@ -192,21 +175,39 @@ class RespondStage(Stage):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
try:
|
||||
await event.send(result)
|
||||
await event.send(MessageChain(non_record_comps))
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -168,30 +169,55 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
):
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(
|
||||
Record(file=audio_path, url=audio_path)
|
||||
)
|
||||
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
@@ -225,6 +251,14 @@ class ResultDecorateStage(Stage):
|
||||
if url:
|
||||
if url.startswith("http"):
|
||||
result.chain = [Image.fromURL(url)]
|
||||
elif (
|
||||
self.ctx.astrbot_config["t2i_use_file_service"]
|
||||
and self.ctx.astrbot_config["callback_api_base"]
|
||||
):
|
||||
token = await file_token_service.register_file(url)
|
||||
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
result.chain = [Image.fromURL(url)]
|
||||
else:
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
@@ -39,6 +39,9 @@ class WakingCheckStage(Stage):
|
||||
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_bot_self_message", False
|
||||
)
|
||||
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_at_all", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -77,11 +80,18 @@ class WakingCheckStage(Stage):
|
||||
event.message_str = event.message_str[len(wake_prefix) :].strip()
|
||||
break
|
||||
if not is_wake:
|
||||
# 检查是否有 at 消息
|
||||
# 检查是否有at消息 / at全体成员消息 / 引用了bot的消息
|
||||
for message in messages:
|
||||
if isinstance(message, At) and (
|
||||
str(message.qq) == str(event.get_self_id())
|
||||
or str(message.qq) == "all"
|
||||
if (
|
||||
(
|
||||
isinstance(message, At)
|
||||
and (str(message.qq) == str(event.get_self_id()))
|
||||
)
|
||||
or (isinstance(message, AtAll) and not self.ignore_at_all)
|
||||
or (
|
||||
isinstance(message, Reply)
|
||||
and str(message.sender_id) == str(event.get_self_id())
|
||||
)
|
||||
):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
@@ -137,7 +147,7 @@ class WakingCheckStage(Stage):
|
||||
if self.no_permission_reply:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
||||
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
|
||||
@@ -62,6 +62,10 @@ class PlatformManager:
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
)
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
case "dingtalk":
|
||||
@@ -73,7 +77,15 @@ class PlatformManager:
|
||||
case "wecom":
|
||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa
|
||||
)
|
||||
case "discord":
|
||||
from .sources.discord.discord_platform_adapter import (
|
||||
DiscordPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "slack":
|
||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||
|
||||
@@ -3,7 +3,16 @@ import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Node,
|
||||
Nodes,
|
||||
Plain,
|
||||
Record,
|
||||
Video,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
|
||||
|
||||
@@ -14,44 +23,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
|
||||
"""修复部分字段"""
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
return {
|
||||
"type": segment.type.lower(),
|
||||
"data": {
|
||||
"file": f"base64://{bs64}",
|
||||
},
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
elif isinstance(segment, Video):
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
else:
|
||||
# For other segments, we simply convert them to a dict by calling toDict
|
||||
return segment.toDict()
|
||||
|
||||
@staticmethod
|
||||
async def _parse_onebot_json(message_chain: MessageChain):
|
||||
"""解析成 OneBot json 格式"""
|
||||
ret = []
|
||||
for segment in message_chain.chain:
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d["type"] = "text"
|
||||
d["data"]["text"] = segment.text.strip()
|
||||
# 如果是空文本或者只带换行符的文本,不发送
|
||||
if not d["data"]["text"]:
|
||||
if not segment.text.strip():
|
||||
continue
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
d["data"] = {
|
||||
"file": f"base64://{bs64}",
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d["data"] = {
|
||||
"qq": str(segment.qq) # 转换为字符串
|
||||
}
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
|
||||
if not ret:
|
||||
return
|
||||
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 转发消息不能和普通消息混在一起发送
|
||||
send_one_by_one = True
|
||||
break
|
||||
|
||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||
send_one_by_one = any(
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
||||
)
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
@@ -61,7 +72,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = seg.toDict()
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
@@ -70,6 +82,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
await self.bot.call_action(
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
elif isinstance(seg, File):
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
[d],
|
||||
)
|
||||
else:
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
@@ -79,6 +97,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if not ret:
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
@@ -103,6 +103,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if event["post_type"] == "message":
|
||||
abm = await self._convert_handle_message_event(event)
|
||||
if abm.sender.user_id == "2854196310":
|
||||
# 屏蔽 QQ 管家的消息
|
||||
return
|
||||
elif event["post_type"] == "notice":
|
||||
abm = await self._convert_handle_notice_event(event)
|
||||
elif event["post_type"] == "request":
|
||||
@@ -217,9 +220,12 @@ class AiocqhttpAdapter(Platform):
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
a = None
|
||||
if t == "text":
|
||||
# 合并相邻文本段
|
||||
message_str = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
a = ComponentTypes[t](text=message_str) # noqa: F405
|
||||
current_text = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
if not current_text:
|
||||
# 如果文本段为空,则跳过
|
||||
continue
|
||||
message_str += current_text
|
||||
a = ComponentTypes[t](text=current_text) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
elif t == "file":
|
||||
@@ -287,6 +293,42 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
elif t == "at":
|
||||
first_at_self_processed = False
|
||||
|
||||
for m in m_group:
|
||||
try:
|
||||
if m["data"]["qq"] == "all":
|
||||
abm.message.append(At(qq="all", name="全体成员"))
|
||||
continue
|
||||
|
||||
at_info = await self.bot.call_action(
|
||||
action="get_stranger_info",
|
||||
user_id=int(m["data"]["qq"]),
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("nick", "")
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
At(
|
||||
qq=m["data"]["qq"],
|
||||
name=nickname,
|
||||
)
|
||||
)
|
||||
|
||||
if is_at_self and not first_at_self_processed:
|
||||
# 第一个@是机器人,不添加到message_str
|
||||
first_at_self_processed = True
|
||||
else:
|
||||
# 非第一个@机器人或@其他用户,添加到message_str
|
||||
message_str += f" @{nickname} "
|
||||
else:
|
||||
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
else:
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -19,6 +20,7 @@ from ...register import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from dingtalk_stream import AckMessage
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
@@ -152,7 +154,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
"downloadCode": download_code,
|
||||
"robotCode": robot_code,
|
||||
}
|
||||
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||
|
||||
@@ -32,31 +32,31 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
logger.warning(
|
||||
"dingtalk only support url image, not: " + segment.file
|
||||
)
|
||||
continue
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
markdown_str += f"\n\n"
|
||||
elif segment.file and segment.file.startswith("base64://"):
|
||||
logger.warning("dingtalk only support url image, not base64")
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
"dingtalk only support url image, not: " + segment.file
|
||||
)
|
||||
continue
|
||||
|
||||
ret = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
try:
|
||||
if not segment.file:
|
||||
logger.warning("钉钉图片 segment 缺少 file 字段,跳过")
|
||||
continue
|
||||
if segment.file.startswith(("http://", "https://")):
|
||||
image_url = segment.file
|
||||
else:
|
||||
image_url = await segment.register_to_file_service()
|
||||
|
||||
markdown_str = f"\n\n"
|
||||
|
||||
ret = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"钉钉图片处理失败: {e}")
|
||||
logger.warning(f"跳过图片发送: {image_path}")
|
||||
continue
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
126
astrbot/core/platform/sources/discord/client.py
Normal file
126
astrbot/core/platform/sources/discord/client.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import discord
|
||||
from astrbot import logger
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# Discord Bot客户端
|
||||
class DiscordBotClient(discord.Bot):
|
||||
"""Discord客户端封装"""
|
||||
|
||||
def __init__(self, token: str, proxy: str = None):
|
||||
self.token = token
|
||||
self.proxy = proxy
|
||||
|
||||
# 设置Intent权限,遵循权限最小化原则
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True # 订阅消息内容事件 (Privileged)
|
||||
intents.members = True # 订阅成员事件 (Privileged)
|
||||
|
||||
# 初始化Bot
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received = None
|
||||
self.on_ready_once_callback = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
@override
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
if self.on_ready_once_callback and not self._ready_once_fired:
|
||||
self._ready_once_fired = True
|
||||
try:
|
||||
await self.on_ready_once_callback()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
"bot_id": str(self.user.id),
|
||||
"content": message.content,
|
||||
"username": message.author.display_name,
|
||||
"userid": str(message.author.id),
|
||||
"message_id": str(message.id),
|
||||
"channel_id": str(message.channel.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
"type": "message",
|
||||
"is_mentioned": is_mentioned,
|
||||
"clean_content": message.clean_content,
|
||||
}
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
"content": self._extract_interaction_content(interaction),
|
||||
"username": interaction.user.display_name,
|
||||
"userid": str(interaction.user.id),
|
||||
"message_id": str(interaction.id),
|
||||
"channel_id": str(interaction.channel_id)
|
||||
if interaction.channel_id
|
||||
else None,
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
|
||||
)
|
||||
|
||||
if self.on_message_received:
|
||||
message_data = self._create_message_data(message)
|
||||
await self.on_message_received(message_data)
|
||||
|
||||
|
||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||
"""从交互中提取内容"""
|
||||
interaction_type = interaction.type
|
||||
interaction_data = getattr(interaction, "data", {})
|
||||
|
||||
if not interaction_data:
|
||||
return ""
|
||||
|
||||
if interaction_type == discord.InteractionType.application_command:
|
||||
command_name = interaction_data.get("name", "")
|
||||
if options := interaction_data.get("options", []):
|
||||
params = " ".join(
|
||||
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
|
||||
)
|
||||
return f"/{command_name} {params}"
|
||||
return f"/{command_name}"
|
||||
|
||||
elif interaction_type == discord.InteractionType.component:
|
||||
custom_id = interaction_data.get("custom_id", "")
|
||||
component_type = interaction_data.get("component_type", "")
|
||||
return f"component:{custom_id}:{component_type}"
|
||||
|
||||
return str(interaction_data)
|
||||
|
||||
async def start_polling(self):
|
||||
"""开始轮询消息,这是个阻塞方法"""
|
||||
await self.start(self.token)
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
if not self.is_closed():
|
||||
await super().close()
|
||||
133
astrbot/core/platform/sources/discord/components.py
Normal file
133
astrbot/core/platform/sources/discord/components.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import discord
|
||||
from typing import List
|
||||
from astrbot.api.message_components import BaseMessageComponent
|
||||
|
||||
|
||||
# Discord专用组件
|
||||
class DiscordEmbed(BaseMessageComponent):
|
||||
"""Discord Embed消息组件"""
|
||||
|
||||
type: str = "discord_embed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
color: int = None,
|
||||
url: str = None,
|
||||
thumbnail: str = None,
|
||||
image: str = None,
|
||||
footer: str = None,
|
||||
fields: List[dict] = None,
|
||||
):
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.color = color
|
||||
self.url = url
|
||||
self.thumbnail = thumbnail
|
||||
self.image = image
|
||||
self.footer = footer
|
||||
self.fields = fields or []
|
||||
|
||||
def to_discord_embed(self) -> discord.Embed:
|
||||
"""转换为Discord Embed对象"""
|
||||
embed = discord.Embed()
|
||||
|
||||
if self.title:
|
||||
embed.title = self.title
|
||||
if self.description:
|
||||
embed.description = self.description
|
||||
if self.color:
|
||||
embed.color = self.color
|
||||
if self.url:
|
||||
embed.url = self.url
|
||||
if self.thumbnail:
|
||||
embed.set_thumbnail(url=self.thumbnail)
|
||||
if self.image:
|
||||
embed.set_image(url=self.image)
|
||||
if self.footer:
|
||||
embed.set_footer(text=self.footer)
|
||||
|
||||
for field in self.fields:
|
||||
embed.add_field(
|
||||
name=field.get("name", ""),
|
||||
value=field.get("value", ""),
|
||||
inline=field.get("inline", False),
|
||||
)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
class DiscordButton(BaseMessageComponent):
|
||||
"""Discord按钮组件"""
|
||||
|
||||
type: str = "discord_button"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str,
|
||||
custom_id: str = None,
|
||||
style: str = "primary",
|
||||
emoji: str = None,
|
||||
url: str = None,
|
||||
disabled: bool = False,
|
||||
):
|
||||
self.label = label
|
||||
self.custom_id = custom_id
|
||||
self.style = style
|
||||
self.emoji = emoji
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
|
||||
class DiscordReference(BaseMessageComponent):
|
||||
"""Discord引用组件"""
|
||||
type: str = "discord_reference"
|
||||
def __init__(self, message_id: str, channel_id: str):
|
||||
self.message_id = message_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
|
||||
class DiscordView(BaseMessageComponent):
|
||||
"""Discord视图组件,包含按钮和选择菜单"""
|
||||
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(
|
||||
self, components: List[BaseMessageComponent] = None, timeout: float = None
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def to_discord_view(self) -> discord.ui.View:
|
||||
"""转换为Discord View对象"""
|
||||
view = discord.ui.View(timeout=self.timeout)
|
||||
|
||||
for component in self.components:
|
||||
if isinstance(component, DiscordButton):
|
||||
button_style = getattr(
|
||||
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
|
||||
)
|
||||
|
||||
if component.url:
|
||||
# URL按钮
|
||||
button = discord.ui.Button(
|
||||
label=component.label,
|
||||
style=discord.ButtonStyle.link,
|
||||
url=component.url,
|
||||
emoji=component.emoji,
|
||||
disabled=component.disabled,
|
||||
)
|
||||
else:
|
||||
# 普通按钮
|
||||
button = discord.ui.Button(
|
||||
label=component.label,
|
||||
style=button_style,
|
||||
custom_id=component.custom_id,
|
||||
emoji=component.emoji,
|
||||
disabled=component.disabled,
|
||||
)
|
||||
|
||||
view.add_item(button)
|
||||
|
||||
return view
|
||||
@@ -0,0 +1,412 @@
|
||||
import asyncio
|
||||
import discord
|
||||
import sys
|
||||
import re
|
||||
from discord.abc import Messageable
|
||||
from discord.channel import DMChannel
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, File
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from .client import DiscordBotClient
|
||||
from .discord_platform_event import DiscordPlatformEvent
|
||||
|
||||
from typing import Any, Tuple
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 注册平台适配器
|
||||
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
|
||||
class DiscordPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.registered_handlers = []
|
||||
# 指令注册相关
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
|
||||
self.activity_name = self.config.get("discord_activity_name", None)
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
# 创建一个 message_obj 以便在 event 中使用
|
||||
message_obj = AstrBotMessage()
|
||||
if "_" in session.session_id:
|
||||
session.session_id = session.session_id.split("_")[1]
|
||||
channel_id_str = session.session_id
|
||||
channel = None
|
||||
try:
|
||||
channel_id = int(channel_id_str)
|
||||
channel = self.client.get_channel(channel_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"[Discord] Invalid channel ID format: {channel_id_str}")
|
||||
|
||||
if channel:
|
||||
message_obj.type = self._get_message_type(channel)
|
||||
message_obj.group_id = self._get_channel_id(channel)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[Discord] Can't get channel info for {channel_id_str}, will guess message type."
|
||||
)
|
||||
message_obj.type = MessageType.GROUP_MESSAGE
|
||||
message_obj.group_id = session.session_id
|
||||
|
||||
message_obj.message_str = message_chain.get_plain_text()
|
||||
message_obj.sender = MessageMember(
|
||||
user_id=str(self.client_self_id), nickname=self.client.user.display_name
|
||||
)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain
|
||||
|
||||
# 创建临时事件对象来发送消息
|
||||
temp_event = DiscordPlatformEvent(
|
||||
message_str=message_chain.get_plain_text(),
|
||||
message_obj=message_obj,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
await temp_event.send(message_chain)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""返回平台元数据"""
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
default_config_tmpl=self.config,
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
"""主要运行逻辑"""
|
||||
|
||||
# 初始化回调函数
|
||||
async def on_received(message_data):
|
||||
logger.debug(f"[Discord] 收到消息: {message_data}")
|
||||
if self.client_self_id is None:
|
||||
self.client_self_id = message_data.get("bot_id")
|
||||
abm = await self.convert_message(data=message_data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
# 初始化 Discord 客户端
|
||||
token = str(self.config.get("discord_token"))
|
||||
if not token:
|
||||
logger.error("[Discord] Bot Token 未配置。请在配置文件中正确设置 token。")
|
||||
return
|
||||
|
||||
proxy = self.config.get("discord_proxy") or None
|
||||
self.client = DiscordBotClient(token, proxy)
|
||||
self.client.on_message_received = on_received
|
||||
|
||||
async def callback():
|
||||
if self.enable_command_register:
|
||||
await self._collect_and_register_commands()
|
||||
if self.activity_name:
|
||||
await self.client.change_presence(
|
||||
status=discord.Status.online,
|
||||
activity=discord.CustomActivity(name=self.activity_name),
|
||||
)
|
||||
|
||||
self.client.on_ready_once_callback = callback
|
||||
|
||||
try:
|
||||
await self.client.start_polling()
|
||||
except discord.errors.LoginFailure:
|
||||
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
|
||||
except discord.errors.ConnectionClosed:
|
||||
logger.warning("[Discord] 与 Discord 的连接已关闭。")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True)
|
||||
|
||||
def _get_message_type(
|
||||
self, channel: Messageable, guild_id: int | None = None
|
||||
) -> MessageType:
|
||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||
if guild_id is not None:
|
||||
return MessageType.GROUP_MESSAGE
|
||||
if isinstance(channel, DMChannel) or getattr(channel, "guild", None) is None:
|
||||
return MessageType.FRIEND_MESSAGE
|
||||
return MessageType.GROUP_MESSAGE
|
||||
|
||||
def _get_channel_id(self, channel: Messageable) -> str:
|
||||
"""根据 channel 对象获取ID"""
|
||||
return str(getattr(channel, "id", None))
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
is_mentioned = data.get("is_mentioned", False)
|
||||
|
||||
content = message.content
|
||||
|
||||
# 如果机器人被@,移除@部分
|
||||
if (
|
||||
is_mentioned
|
||||
and self.client
|
||||
and self.client.user
|
||||
and self.client.user in message.mentions
|
||||
):
|
||||
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
|
||||
mention_str = f"<@{self.client.user.id}>"
|
||||
mention_str_nickname = (
|
||||
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
|
||||
)
|
||||
|
||||
if content.startswith(mention_str):
|
||||
content = content[len(mention_str) :].lstrip()
|
||||
elif content.startswith(mention_str_nickname):
|
||||
content = content[len(mention_str_nickname) :].lstrip()
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
abm.type = self._get_message_type(message.channel)
|
||||
abm.group_id = self._get_channel_id(message.channel)
|
||||
|
||||
abm.message_str = content
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(message.author.id), nickname=message.author.display_name
|
||||
)
|
||||
|
||||
message_chain = []
|
||||
if abm.message_str:
|
||||
message_chain.append(Plain(text=abm.message_str))
|
||||
|
||||
if message.attachments:
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith(
|
||||
"image/"
|
||||
):
|
||||
message_chain.append(
|
||||
Image(file=attachment.url, filename=attachment.filename)
|
||||
)
|
||||
else:
|
||||
message_chain.append(
|
||||
File(name=attachment.filename, url=attachment.url)
|
||||
)
|
||||
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
|
||||
async def convert_message(self, data: dict) -> AstrBotMessage:
|
||||
"""将平台消息转换成 AstrBotMessage"""
|
||||
# 由于 on_interaction 已被禁用,我们只处理普通消息
|
||||
return self._convert_message_to_abm(data)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
|
||||
"""处理消息"""
|
||||
message_event = DiscordPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
interaction_followup_webhook=followup_webhook,
|
||||
)
|
||||
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 检查是否被@
|
||||
is_mention = (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
and self.client.user in message.raw_message.mentions
|
||||
)
|
||||
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
@override
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("[Discord] 正在终止适配器...")
|
||||
|
||||
# 清理指令
|
||||
if self.enable_command_register and self.client:
|
||||
logger.info("[Discord] 正在清理已注册的斜杠指令...")
|
||||
try:
|
||||
# 传入空的列表来清除所有全局指令
|
||||
# 如果指定了 guild_id,则只清除该服务器的指令
|
||||
await self.client.sync_commands(
|
||||
commands=[], guild_ids=[self.guild_id] if self.guild_id else None
|
||||
)
|
||||
logger.info("[Discord] 指令清理完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
|
||||
|
||||
if self.client and hasattr(self.client, "close"):
|
||||
await self.client.close()
|
||||
logger.info("[Discord] 适配器已终止。")
|
||||
|
||||
def register_handler(self, handler_info):
|
||||
"""注册处理器信息"""
|
||||
self.registered_handlers.append(handler_info)
|
||||
|
||||
async def _collect_and_register_commands(self):
|
||||
"""收集所有指令并注册到Discord"""
|
||||
logger.info("[Discord] 开始收集并注册斜杠指令...")
|
||||
registered_commands = []
|
||||
|
||||
for handler_md in star_handlers_registry:
|
||||
if not star_map[handler_md.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_md.event_filters:
|
||||
cmd_info = self._extract_command_info(event_filter, handler_md)
|
||||
if not cmd_info:
|
||||
continue
|
||||
|
||||
cmd_name, description, cmd_filter_instance = cmd_info
|
||||
|
||||
# 创建动态回调
|
||||
callback = self._create_dynamic_callback(cmd_name)
|
||||
|
||||
# 创建一个通用的参数选项来接收所有文本输入
|
||||
options = [
|
||||
discord.Option(
|
||||
name="params",
|
||||
description="指令的所有参数",
|
||||
type=discord.SlashCommandOptionType.string,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
# 创建SlashCommand
|
||||
slash_command = discord.SlashCommand(
|
||||
name=cmd_name,
|
||||
description=description,
|
||||
func=callback,
|
||||
options=options,
|
||||
guild_ids=[self.guild_id] if self.guild_id else None,
|
||||
)
|
||||
self.client.add_application_command(slash_command)
|
||||
registered_commands.append(cmd_name)
|
||||
|
||||
if registered_commands:
|
||||
logger.info(
|
||||
f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}"
|
||||
)
|
||||
else:
|
||||
logger.info("[Discord] 没有发现可注册的指令。")
|
||||
|
||||
# 使用 Pycord 的方法同步指令
|
||||
# 注意:这可能需要一些时间,并且有频率限制
|
||||
await self.client.sync_commands()
|
||||
logger.info("[Discord] 指令同步完成。")
|
||||
|
||||
def _create_dynamic_callback(self, cmd_name: str):
|
||||
"""为每个指令动态创建一个异步回调函数"""
|
||||
|
||||
async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None):
|
||||
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
|
||||
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
|
||||
logger.debug(f"[Discord] 回调函数参数: {ctx}")
|
||||
logger.debug(f"[Discord] 回调函数参数: {params}")
|
||||
message_str_for_filter = cmd_name
|
||||
if params:
|
||||
message_str_for_filter += f" {params}"
|
||||
|
||||
logger.debug(
|
||||
f"[Discord] 斜杠指令 '{cmd_name}' 被触发。 "
|
||||
f"原始参数: '{params}'. "
|
||||
f"构建的指令字符串: '{message_str_for_filter}'"
|
||||
)
|
||||
|
||||
# 尝试立即响应,防止超时
|
||||
followup_webhook = None
|
||||
try:
|
||||
await ctx.defer()
|
||||
followup_webhook = ctx.followup
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}")
|
||||
|
||||
# 2. 构建 AstrBotMessage
|
||||
abm = AstrBotMessage()
|
||||
abm.type = self._get_message_type(ctx.channel, ctx.guild_id)
|
||||
abm.group_id = self._get_channel_id(ctx.channel)
|
||||
abm.message_str = message_str_for_filter
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(ctx.author.id), nickname=ctx.author.display_name
|
||||
)
|
||||
abm.message = [Plain(text=message_str_for_filter)]
|
||||
abm.raw_message = ctx.interaction
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(ctx.channel_id)
|
||||
abm.message_id = str(ctx.interaction.id)
|
||||
|
||||
# 3. 将消息和 webhook 分别交给 handle_msg 处理
|
||||
await self.handle_msg(abm, followup_webhook)
|
||||
|
||||
return dynamic_callback
|
||||
|
||||
@staticmethod
|
||||
def _extract_command_info(
|
||||
event_filter: Any, handler_metadata: StarHandlerMetadata
|
||||
) -> Tuple[str, str, CommandFilter] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
# is_group = False
|
||||
cmd_filter_instance = None
|
||||
|
||||
if isinstance(event_filter, CommandFilter):
|
||||
# 暂不支持子指令注册为斜杠指令
|
||||
if (
|
||||
event_filter.parent_command_names
|
||||
and event_filter.parent_command_names != [""]
|
||||
):
|
||||
return None
|
||||
cmd_name = event_filter.command_name
|
||||
cmd_filter_instance = event_filter
|
||||
|
||||
elif isinstance(event_filter, CommandGroupFilter):
|
||||
# 暂不支持指令组直接注册为斜杠指令,因为它们没有 handle 方法
|
||||
return None
|
||||
|
||||
if not cmd_name:
|
||||
return None
|
||||
|
||||
# Discord 斜杠指令名称规范
|
||||
if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name):
|
||||
logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}")
|
||||
return None
|
||||
|
||||
description = handler_metadata.desc or f"指令: {cmd_name}"
|
||||
if len(description) > 100:
|
||||
description = f"{description[:97]}..."
|
||||
|
||||
return cmd_name, description, cmd_filter_instance
|
||||
291
astrbot/core/platform/sources/discord/discord_platform_event.py
Normal file
291
astrbot/core/platform/sources/discord/discord_platform_event.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import asyncio
|
||||
import discord
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, At
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
Reply,
|
||||
)
|
||||
from astrbot import logger
|
||||
from .client import DiscordBotClient
|
||||
from .components import DiscordEmbed, DiscordView
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 自定义Discord视图组件(兼容旧版本)
|
||||
class DiscordViewComponent(BaseMessageComponent):
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(self, view: discord.ui.View):
|
||||
self.view = view
|
||||
|
||||
|
||||
class DiscordPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: DiscordBotClient,
|
||||
interaction_followup_webhook: Optional[discord.Webhook] = None,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
self.interaction_followup_webhook = interaction_followup_webhook
|
||||
|
||||
@override
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息到Discord平台"""
|
||||
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
try:
|
||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
kwargs = {}
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
if files:
|
||||
kwargs["files"] = files
|
||||
if view:
|
||||
kwargs["view"] = view
|
||||
if embeds:
|
||||
kwargs["embeds"] = embeds
|
||||
if reference_message_id and not self.interaction_followup_webhook:
|
||||
kwargs["reference"] = self.client.get_message(int(reference_message_id))
|
||||
if not kwargs:
|
||||
logger.debug("[Discord] 尝试发送空消息,已忽略。")
|
||||
return
|
||||
|
||||
# 根据上下文执行发送/回复操作
|
||||
try:
|
||||
# -- 斜杠指令/交互上下文 --
|
||||
if self.interaction_followup_webhook:
|
||||
await self.interaction_followup_webhook.send(**kwargs)
|
||||
|
||||
# -- 常规消息上下文 --
|
||||
else:
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
return
|
||||
else:
|
||||
await channel.send(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def _get_channel(self) -> Optional[discord.abc.Messageable]:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
return self.client.get_channel(
|
||||
channel_id
|
||||
) or await self.client.fetch_channel(channel_id)
|
||||
except (ValueError, discord.errors.NotFound, discord.errors.Forbidden):
|
||||
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
|
||||
return None
|
||||
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
content = ""
|
||||
files = []
|
||||
view = None
|
||||
embeds = []
|
||||
reference_message_id = None
|
||||
for i in message.chain: # 遍历消息链
|
||||
if isinstance(i, Plain): # 如果是文字类型的
|
||||
content += i.text
|
||||
elif isinstance(i, Reply):
|
||||
reference_message_id = i.id
|
||||
elif isinstance(i, At):
|
||||
content += f"<@{i.qq}>"
|
||||
elif isinstance(i, Image):
|
||||
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
|
||||
try:
|
||||
filename = getattr(i, "filename", None)
|
||||
file_content = getattr(i, "file", None)
|
||||
|
||||
if not file_content:
|
||||
logger.warning(f"[Discord] Image 组件没有 file 属性: {i}")
|
||||
continue
|
||||
|
||||
discord_file = None
|
||||
|
||||
# 1. URL
|
||||
if file_content.startswith("http"):
|
||||
logger.debug(f"[Discord] 处理 URL 图片: {file_content}")
|
||||
embed = discord.Embed().set_image(url=file_content)
|
||||
embeds.append(embed)
|
||||
continue
|
||||
|
||||
# 2. File URI
|
||||
elif file_content.startswith("file:///"):
|
||||
logger.debug(f"[Discord] 处理 File URI: {file_content}")
|
||||
path = Path(file_content[8:])
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
discord_file = discord.File(
|
||||
BytesIO(file_bytes), filename=filename or path.name
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 图片文件不存在: {path}")
|
||||
|
||||
# 3. Base64 URI
|
||||
elif file_content.startswith("base64://"):
|
||||
logger.debug("[Discord] 处理 Base64 URI")
|
||||
b64_data = file_content.split("base64://", 1)[1]
|
||||
missing_padding = len(b64_data) % 4
|
||||
if missing_padding:
|
||||
b64_data += "=" * (4 - missing_padding)
|
||||
img_bytes = base64.b64decode(b64_data)
|
||||
discord_file = discord.File(
|
||||
BytesIO(img_bytes), filename=filename or "image.png"
|
||||
)
|
||||
|
||||
# 4. 裸 Base64 或本地路径
|
||||
else:
|
||||
try:
|
||||
logger.debug("[Discord] 尝试作为裸 Base64 处理")
|
||||
b64_data = file_content
|
||||
missing_padding = len(b64_data) % 4
|
||||
if missing_padding:
|
||||
b64_data += "=" * (4 - missing_padding)
|
||||
img_bytes = base64.b64decode(b64_data)
|
||||
discord_file = discord.File(
|
||||
BytesIO(img_bytes), filename=filename or "image.png"
|
||||
)
|
||||
except (ValueError, TypeError, base64.binascii.Error):
|
||||
logger.debug(
|
||||
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}"
|
||||
)
|
||||
path = Path(file_content)
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
discord_file = discord.File(
|
||||
BytesIO(file_bytes), filename=filename or path.name
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 图片文件不存在: {path}")
|
||||
|
||||
if discord_file:
|
||||
files.append(discord_file)
|
||||
|
||||
except Exception:
|
||||
# 使用 getattr 来安全地访问 i.file,以防 i 本身就是问题
|
||||
file_info = getattr(i, "file", "未知")
|
||||
logger.error(
|
||||
f"[Discord] 处理图片时发生未知严重错误: {file_info}",
|
||||
exc_info=True,
|
||||
)
|
||||
elif isinstance(i, File):
|
||||
try:
|
||||
file_path_str = await i.get_file()
|
||||
if file_path_str:
|
||||
path = Path(file_path_str)
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
files.append(
|
||||
discord.File(BytesIO(file_bytes),
|
||||
filename=i.name)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[Discord] 获取文件失败,路径不存在: {file_path_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 获取文件失败: {i.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}")
|
||||
elif isinstance(i, DiscordEmbed):
|
||||
# Discord Embed消息
|
||||
embeds.append(i.to_discord_embed())
|
||||
elif isinstance(i, DiscordView):
|
||||
# Discord视图组件(按钮、选择菜单等)
|
||||
view = i.to_discord_view()
|
||||
elif isinstance(i, DiscordViewComponent):
|
||||
# 如果消息链中包含Discord视图组件(兼容旧版本)
|
||||
if isinstance(i.view, discord.ui.View):
|
||||
view = i.view
|
||||
else:
|
||||
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
|
||||
|
||||
if len(content) > 2000:
|
||||
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
|
||||
content = content[:2000]
|
||||
return content, files, view, embeds, reference_message_id
|
||||
|
||||
async def react(self, emoji: str):
|
||||
"""对原消息添加反应"""
|
||||
try:
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "add_reaction"
|
||||
):
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
def is_slash_command(self) -> bool:
|
||||
"""判断是否为斜杠命令"""
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
def is_button_interaction(self) -> bool:
|
||||
"""判断是否为按钮交互"""
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
def is_mentioned(self) -> bool:
|
||||
"""判断机器人是否被@"""
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "mentions"
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
)
|
||||
return False
|
||||
|
||||
def get_mention_clean_content(self) -> str:
|
||||
"""获取去除@后的清洁内容"""
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "clean_content"
|
||||
):
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return self.message_str
|
||||
@@ -15,6 +15,7 @@ from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
@@ -250,7 +251,10 @@ class SimpleGewechatClient:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
@@ -458,8 +462,10 @@ class SimpleGewechatClient:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
if os.path.exists("data/temp/gewe_code"):
|
||||
with open("data/temp/gewe_code", "r") as f:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
@@ -470,9 +476,9 @@ class SimpleGewechatClient:
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
@@ -6,7 +6,7 @@ import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -21,6 +21,7 @@ from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
@@ -106,7 +107,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
# 根据 url 下载视频
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
@@ -115,7 +117,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
thumb_path = f"data/temp/gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
@@ -154,7 +159,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
@@ -173,7 +179,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
import lark_oapi as lark
|
||||
@@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -40,7 +42,8 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
# save as temp file
|
||||
file_path = f"data/temp/{uuid.uuid4()}_test.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
|
||||
162
astrbot/core/platform/sources/slack/client.py
Normal file
162
astrbot/core/platform/sources/slack/client.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
import hmac
|
||||
import hashlib
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, Optional
|
||||
from quart import Quart, request, Response
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class SlackWebhookClient:
|
||||
"""Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
web_client: AsyncWebClient,
|
||||
signing_secret: str,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 3000,
|
||||
path: str = "/slack/events",
|
||||
event_handler: Optional[Callable] = None,
|
||||
):
|
||||
self.web_client = web_client
|
||||
self.signing_secret = signing_secret
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.event_handler = event_handler
|
||||
|
||||
self.app = Quart(__name__)
|
||||
self._setup_routes()
|
||||
|
||||
# 禁用 Quart 的默认日志输出
|
||||
logging.getLogger("quart.app").setLevel(logging.WARNING)
|
||||
logging.getLogger("quart.serving").setLevel(logging.WARNING)
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置路由"""
|
||||
|
||||
@self.app.route(self.path, methods=["POST"])
|
||||
async def slack_events():
|
||||
"""处理 Slack 事件"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await request.get_data()
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
timestamp = request.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = request.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
sig_basestring.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
if event_data.get("type") == "url_verification":
|
||||
return {"challenge": event_data.get("challenge")}
|
||||
# 处理事件
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
|
||||
@self.app.route("/health", methods=["GET"])
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "ok", "service": "slack-webhook"}
|
||||
|
||||
async def start(self):
|
||||
"""启动 Webhook 服务器"""
|
||||
logger.info(
|
||||
f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}..."
|
||||
)
|
||||
|
||||
await self.app.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
debug=False,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def stop(self):
|
||||
"""停止 Webhook 服务器"""
|
||||
self.shutdown_event.set()
|
||||
logger.info("Slack Webhook 服务器已停止")
|
||||
|
||||
|
||||
class SlackSocketClient:
|
||||
"""Slack Socket 模式客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
web_client: AsyncWebClient,
|
||||
app_token: str,
|
||||
event_handler: Optional[Callable] = None,
|
||||
):
|
||||
self.web_client = web_client
|
||||
self.app_token = app_token
|
||||
self.event_handler = event_handler
|
||||
self.socket_client = None
|
||||
|
||||
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
|
||||
"""处理 Socket Mode 事件"""
|
||||
try:
|
||||
# 确认收到事件
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
await self.socket_client.send_socket_mode_response(response)
|
||||
|
||||
# 处理事件
|
||||
if self.event_handler:
|
||||
await self.event_handler(req)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Socket Mode 事件时出错: {e}")
|
||||
|
||||
async def start(self):
|
||||
"""启动 Socket Mode 连接"""
|
||||
self.socket_client = SocketModeClient(
|
||||
app_token=self.app_token,
|
||||
logger=logger,
|
||||
web_client=self.web_client,
|
||||
)
|
||||
|
||||
# 注册事件处理器
|
||||
self.socket_client.socket_mode_request_listeners.append(self._handle_events)
|
||||
|
||||
logger.info("Slack Socket Mode 客户端启动中...")
|
||||
await self.socket_client.connect()
|
||||
|
||||
async def stop(self):
|
||||
"""停止 Socket Mode 连接"""
|
||||
if self.socket_client:
|
||||
await self.socket_client.disconnect()
|
||||
await self.socket_client.close()
|
||||
logger.info("Slack Socket Mode 客户端已停止")
|
||||
396
astrbot/core/platform/sources/slack/slack_adapter.py
Normal file
396
astrbot/core/platform/sources/slack/slack_adapter.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import aiohttp
|
||||
import re
|
||||
import base64
|
||||
from typing import Awaitable, Any
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from .slack_event import SlackMessageEvent
|
||||
from .client import SlackWebhookClient, SlackSocketClient
|
||||
from astrbot.api.message_components import * # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"slack", "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。"
|
||||
)
|
||||
class SlackAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.bot_token = platform_config.get("bot_token")
|
||||
self.app_token = platform_config.get("app_token")
|
||||
self.signing_secret = platform_config.get("signing_secret")
|
||||
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
|
||||
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
|
||||
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
|
||||
self.webhook_path = platform_config.get(
|
||||
"slack_webhook_path", "/astrbot-slack-webhook/callback"
|
||||
)
|
||||
|
||||
if not self.bot_token:
|
||||
raise ValueError("Slack bot_token 是必需的")
|
||||
|
||||
if self.connection_mode == "socket" and not self.app_token:
|
||||
raise ValueError("Socket Mode 需要 app_token")
|
||||
|
||||
if self.connection_mode == "webhook" and not self.signing_secret:
|
||||
raise ValueError("Webhook Mode 需要 signing_secret")
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
# 初始化 Slack Web Client
|
||||
self.web_client = AsyncWebClient(token=self.bot_token, logger=logger)
|
||||
self.socket_client = None
|
||||
self.webhook_client = None
|
||||
|
||||
self.bot_self_id = None
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
blocks, text = SlackMessageEvent._parse_slack_blocks(
|
||||
message_chain=message_chain, web_client=self.web_client
|
||||
)
|
||||
|
||||
try:
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
# 发送到频道
|
||||
channel_id = (
|
||||
session.session_id.split("_")[-1]
|
||||
if "_" in session.session_id
|
||||
else session.session_id
|
||||
)
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=channel_id,
|
||||
text=text,
|
||||
blocks=blocks if blocks else None,
|
||||
)
|
||||
else:
|
||||
# 发送私信
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=session.session_id,
|
||||
text=text,
|
||||
blocks=blocks if blocks else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Slack 发送消息失败: {e}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, event: dict) -> AstrBotMessage:
|
||||
logger.debug(f"[slack] RawMessage {event}")
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_self_id
|
||||
|
||||
# 获取用户信息
|
||||
user_id = event.get("user", "")
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=user_id)
|
||||
user_data = user_info["user"]
|
||||
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
||||
except Exception:
|
||||
user_name = user_id
|
||||
|
||||
abm.sender = MessageMember(user_id=user_id, nickname=user_name)
|
||||
|
||||
# 判断消息类型
|
||||
channel_id = event.get("channel", "")
|
||||
try:
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
is_im = channel_info["channel"]["is_im"]
|
||||
|
||||
if is_im:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = channel_id
|
||||
except Exception:
|
||||
# 默认作为群组消息处理
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = channel_id
|
||||
|
||||
# 设置会话ID
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{user_id}_{channel_id}"
|
||||
else:
|
||||
abm.session_id = (
|
||||
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
|
||||
)
|
||||
|
||||
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
|
||||
abm.timestamp = int(float(event.get("ts", time.time())))
|
||||
|
||||
# 处理消息内容
|
||||
message_text = event.get("text", "")
|
||||
abm.message_str = message_text
|
||||
abm.message = []
|
||||
|
||||
# 优先使用 blocks 字段解析消息
|
||||
if "blocks" in event and event["blocks"]:
|
||||
abm.message = self._parse_blocks(event["blocks"])
|
||||
# 更新 message_str
|
||||
abm.message_str = ""
|
||||
for component in abm.message:
|
||||
if isinstance(component, Plain):
|
||||
abm.message_str += component.text
|
||||
elif message_text:
|
||||
# 处理传统的文本消息
|
||||
if "<@" in message_text:
|
||||
mentions = re.findall(r"<@([^>]+)>", message_text)
|
||||
for mention in mentions:
|
||||
try:
|
||||
mentioned_user = await self.web_client.users_info(user=mention)
|
||||
user_data = mentioned_user["user"]
|
||||
user_name = user_data.get("real_name") or user_data.get(
|
||||
"name", mention
|
||||
)
|
||||
abm.message.append(At(qq=mention, name=user_name))
|
||||
except Exception:
|
||||
abm.message.append(At(qq=mention, name=""))
|
||||
|
||||
# 清理消息文本中的@标记
|
||||
if clean_text := re.sub(r"<@[^>]+>", "", message_text).strip():
|
||||
abm.message.append(Plain(text=clean_text))
|
||||
else:
|
||||
abm.message.append(Plain(text=message_text))
|
||||
|
||||
# 处理文件附件
|
||||
if "files" in event:
|
||||
for file_info in event["files"]:
|
||||
file_name = file_info.get("name", "unknown")
|
||||
file_url = file_info.get("url_private", "")
|
||||
if file_info.get("mimetype", "").startswith("image/"):
|
||||
file_url = await self.get_file_base64(file_url)
|
||||
abm.message.append(Image.fromBase64(base64=file_url))
|
||||
else:
|
||||
# TODO: 下载鉴权
|
||||
abm.message.append(
|
||||
File(name=file_name, file=file_url, url=file_url)
|
||||
)
|
||||
|
||||
abm.raw_message = event
|
||||
return abm
|
||||
|
||||
def _parse_blocks(self, blocks: list) -> list:
|
||||
"""解析 Slack blocks 格式的消息内容"""
|
||||
message_components = []
|
||||
|
||||
for block in blocks:
|
||||
block_type = block.get("type", "")
|
||||
|
||||
if block_type == "rich_text":
|
||||
# 处理富文本块
|
||||
elements = block.get("elements", [])
|
||||
for element in elements:
|
||||
if element.get("type") == "rich_text_section":
|
||||
# 处理富文本段落
|
||||
section_elements = element.get("elements", [])
|
||||
text_content = ""
|
||||
|
||||
for section_element in section_elements:
|
||||
element_type = section_element.get("type", "")
|
||||
|
||||
if element_type == "text":
|
||||
# 普通文本
|
||||
text_content += section_element.get("text", "")
|
||||
elif element_type == "user":
|
||||
# @用户提及
|
||||
user_id = section_element.get("user_id", "")
|
||||
if user_id:
|
||||
# 将之前的文本内容先添加到组件中
|
||||
if text_content.strip():
|
||||
message_components.append(
|
||||
Plain(text=text_content)
|
||||
)
|
||||
text_content = ""
|
||||
# 添加@提及组件
|
||||
message_components.append(At(qq=user_id, name=""))
|
||||
elif element_type == "channel":
|
||||
# #频道提及
|
||||
channel_id = section_element.get("channel_id", "")
|
||||
text_content += f"#{channel_id}"
|
||||
elif element_type == "link":
|
||||
# 链接
|
||||
url = section_element.get("url", "")
|
||||
link_text = section_element.get("text", url)
|
||||
text_content += f"[{link_text}]({url})"
|
||||
elif element_type == "emoji":
|
||||
# 表情符号
|
||||
emoji_name = section_element.get("name", "")
|
||||
text_content += f":{emoji_name}:"
|
||||
|
||||
if text_content.strip():
|
||||
message_components.append(Plain(text=text_content))
|
||||
|
||||
elif element.get("type") == "rich_text_list":
|
||||
# 处理列表
|
||||
list_items = element.get("elements", [])
|
||||
list_text = ""
|
||||
for item in list_items:
|
||||
if item.get("type") == "rich_text_section":
|
||||
item_elements = item.get("elements", [])
|
||||
item_text = ""
|
||||
for item_element in item_elements:
|
||||
if item_element.get("type") == "text":
|
||||
item_text += item_element.get("text", "")
|
||||
list_text += f"• {item_text}\n"
|
||||
|
||||
if list_text.strip():
|
||||
message_components.append(Plain(text=list_text.strip()))
|
||||
|
||||
elif block_type == "section":
|
||||
# 处理段落块
|
||||
if "text" in block:
|
||||
text_obj = block["text"]
|
||||
if text_obj.get("type") == "mrkdwn":
|
||||
text_content = text_obj.get("text", "")
|
||||
message_components.append(Plain(text=text_content))
|
||||
|
||||
return message_components
|
||||
|
||||
async def _handle_socket_event(self, req: SocketModeRequest):
|
||||
"""处理 Socket Mode 事件"""
|
||||
if req.type == "events_api":
|
||||
# 事件 API
|
||||
event = req.payload.get("event", {})
|
||||
|
||||
# 忽略机器人自己的消息和消息编辑
|
||||
if event.get("subtype") in [
|
||||
"bot_message",
|
||||
"message_changed",
|
||||
"message_deleted",
|
||||
]:
|
||||
return
|
||||
|
||||
if event.get("bot_id"):
|
||||
return
|
||||
|
||||
if event.get("type") in ["message", "app_mention"]:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def get_bot_user_id(self):
|
||||
auth_info = await self.web_client.auth_test()
|
||||
return auth_info.get("user_id")
|
||||
|
||||
async def get_file_base64(self, url: str) -> str:
|
||||
"""下载 Slack 文件并返回 Base64 编码的内容"""
|
||||
headers = {"Authorization": f"Bearer {self.bot_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
content = await resp.read()
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
return base64_content
|
||||
else:
|
||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
self.bot_self_id = await self.get_bot_user_id()
|
||||
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
||||
|
||||
if self.connection_mode == "socket":
|
||||
if not self.app_token:
|
||||
raise ValueError("Socket Mode 需要 app_token")
|
||||
|
||||
# 创建 Socket 客户端
|
||||
self.socket_client = SlackSocketClient(
|
||||
self.web_client, self.app_token, self._handle_socket_event
|
||||
)
|
||||
|
||||
logger.info("Slack 适配器 (Socket Mode) 启动中...")
|
||||
await self.socket_client.start()
|
||||
|
||||
elif self.connection_mode == "webhook":
|
||||
if not self.signing_secret:
|
||||
raise ValueError("Webhook Mode 需要 signing_secret")
|
||||
|
||||
# 创建 Webhook 客户端
|
||||
self.webhook_client = SlackWebhookClient(
|
||||
self.web_client,
|
||||
self.signing_secret,
|
||||
self.webhook_host,
|
||||
self.webhook_port,
|
||||
self.webhook_path,
|
||||
self._handle_webhook_event,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}..."
|
||||
)
|
||||
await self.webhook_client.start()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'"
|
||||
)
|
||||
|
||||
async def _handle_webhook_event(self, event_data: dict):
|
||||
"""处理 Webhook 事件"""
|
||||
event = event_data.get("event", {})
|
||||
|
||||
# 忽略机器人自己的消息和消息编辑
|
||||
if event.get("subtype") in [
|
||||
"bot_message",
|
||||
"message_changed",
|
||||
"message_deleted",
|
||||
]:
|
||||
return
|
||||
|
||||
if event.get("bot_id"):
|
||||
return
|
||||
|
||||
if event.get("type") in ["message", "app_mention"]:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def terminate(self):
|
||||
if self.socket_client:
|
||||
await self.socket_client.stop()
|
||||
if self.webhook_client:
|
||||
await self.webhook_client.stop()
|
||||
logger.info("Slack 适配器已被优雅地关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = SlackMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
web_client=self.web_client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self):
|
||||
return self.web_client
|
||||
237
astrbot/core/platform/sources/slack/slack_event.py
Normal file
237
astrbot/core/platform/sources/slack/slack_event.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Plain,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class SlackMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str,
|
||||
message_obj,
|
||||
platform_meta,
|
||||
session_id,
|
||||
web_client: AsyncWebClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.web_client = web_client
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_slack_block(
|
||||
segment: BaseMessageComponent, web_client: AsyncWebClient
|
||||
) -> dict:
|
||||
"""将消息段转换为 Slack 块格式"""
|
||||
if isinstance(segment, Plain):
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
|
||||
elif isinstance(segment, Image):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
if url.startswith("http"):
|
||||
return {
|
||||
"type": "image",
|
||||
"image_url": url,
|
||||
"alt_text": "图片",
|
||||
}
|
||||
path = await segment.convert_to_file_path()
|
||||
response = await web_client.files_upload_v2(
|
||||
file=path,
|
||||
filename="image.jpg",
|
||||
)
|
||||
if not response["ok"]:
|
||||
logger.error(f"Slack file upload failed: {response['error']}")
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
||||
}
|
||||
image_url = response["files"][0]["url_private"]
|
||||
logger.debug(f"Slack file upload response: {response}")
|
||||
return {
|
||||
"type": "image",
|
||||
"slack_file": {
|
||||
"url": image_url,
|
||||
},
|
||||
"alt_text": "图片",
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
response = await web_client.files_upload_v2(
|
||||
file=url,
|
||||
filename=segment.name or "file",
|
||||
)
|
||||
if not response["ok"]:
|
||||
logger.error(f"Slack file upload failed: {response['error']}")
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
||||
else:
|
||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||
|
||||
@staticmethod
|
||||
async def _parse_slack_blocks(
|
||||
message_chain: MessageChain, web_client: AsyncWebClient
|
||||
):
|
||||
"""解析成 Slack 块格式"""
|
||||
blocks = []
|
||||
text_content = ""
|
||||
|
||||
for segment in message_chain.chain:
|
||||
if isinstance(segment, Plain):
|
||||
text_content += segment.text
|
||||
else:
|
||||
# 如果有文本内容,先添加文本块
|
||||
if text_content.strip():
|
||||
blocks.append(
|
||||
{
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": text_content},
|
||||
}
|
||||
)
|
||||
text_content = ""
|
||||
|
||||
# 添加其他类型的块
|
||||
block = await SlackMessageEvent._from_segment_to_slack_block(
|
||||
segment, web_client
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
# 如果最后还有文本内容
|
||||
if text_content.strip():
|
||||
blocks.append(
|
||||
{"type": "section", "text": {"type": "mrkdwn", "text": text_content}}
|
||||
)
|
||||
|
||||
return blocks, "" if blocks else text_content
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
blocks, text = await SlackMessageEvent._parse_slack_blocks(
|
||||
message, self.web_client
|
||||
)
|
||||
|
||||
try:
|
||||
if self.get_group_id():
|
||||
# 发送到频道
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_group_id(),
|
||||
text=text,
|
||||
blocks=blocks or None,
|
||||
)
|
||||
else:
|
||||
# 发送私信
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_sender_id(),
|
||||
text=text,
|
||||
blocks=blocks or None,
|
||||
)
|
||||
except Exception:
|
||||
# 如果块发送失败,尝试只发送文本
|
||||
fallback_text = ""
|
||||
for segment in message.chain:
|
||||
if isinstance(segment, Plain):
|
||||
fallback_text += segment.text
|
||||
elif isinstance(segment, File):
|
||||
fallback_text += f" [文件: {segment.name}] "
|
||||
elif isinstance(segment, Image):
|
||||
fallback_text += " [图片] "
|
||||
|
||||
if self.get_group_id():
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_group_id(), text=fallback_text
|
||||
)
|
||||
else:
|
||||
await self.web_client.chat_postMessage(
|
||||
channel=self.get_sender_id(), text=fallback_text
|
||||
)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if group_id:
|
||||
channel_id = group_id
|
||||
elif self.get_group_id():
|
||||
channel_id = self.get_group_id()
|
||||
else:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 获取频道信息
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
|
||||
# 获取频道成员
|
||||
members_response = await self.web_client.conversations_members(
|
||||
channel=channel_id
|
||||
)
|
||||
|
||||
members = []
|
||||
for member_id in members_response["members"]:
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=member_id)
|
||||
user_data = user_info["user"]
|
||||
members.append(
|
||||
MessageMember(
|
||||
user_id=member_id,
|
||||
nickname=user_data.get("real_name")
|
||||
or user_data.get("name", member_id),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# 如果获取用户信息失败,使用默认信息
|
||||
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
||||
|
||||
channel_data = channel_info["channel"]
|
||||
return Group(
|
||||
group_id=channel_id,
|
||||
group_name=channel_data.get("name", ""),
|
||||
group_avatar="",
|
||||
group_admins=[], # Slack 的管理员信息需要特殊权限获取
|
||||
group_owner=channel_data.get("creator", ""),
|
||||
members=members,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -144,8 +144,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
command_dict = {}
|
||||
skip_commands = {"start"}
|
||||
|
||||
for handler_md in star_handlers_registry._handlers:
|
||||
handler_metadata = handler_md[1]
|
||||
for handler_md in star_handlers_registry:
|
||||
handler_metadata = handler_md
|
||||
if not star_map[handler_metadata.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_metadata.event_filters:
|
||||
@@ -282,10 +282,12 @@ class TelegramPlatformAdapter(Platform):
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
plain_text = (
|
||||
plain_text[: entity.offset]
|
||||
+ plain_text[entity.offset + entity.length :]
|
||||
)
|
||||
# 如果mention是当前bot则移除;否则保留
|
||||
if name.lower() == context.bot.username.lower():
|
||||
plain_text = (
|
||||
plain_text[: entity.offset]
|
||||
+ plain_text[entity.offset + entity.length :]
|
||||
)
|
||||
|
||||
if plain_text:
|
||||
message.message.append(Comp.Plain(plain_text))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -13,9 +15,20 @@ from astrbot.api.message_components import (
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
# Telegram 的最大消息长度限制
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
SPLIT_PATTERNS = {
|
||||
"paragraph": re.compile(r"\n\n"),
|
||||
"line": re.compile(r"\n"),
|
||||
"sentence": re.compile(r"[.!?。!?]"),
|
||||
"word": re.compile(r"\s"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
@@ -27,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
|
||||
def _split_message(self, text: str) -> list[str]:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
chunks.append(text)
|
||||
break
|
||||
|
||||
split_point = self.MAX_MESSAGE_LENGTH
|
||||
segment = text[: self.MAX_MESSAGE_LENGTH]
|
||||
|
||||
for _, pattern in self.SPLIT_PATTERNS.items():
|
||||
if matches := list(pattern.finditer(segment)):
|
||||
last_match = matches[-1]
|
||||
split_point = last_match.end()
|
||||
break
|
||||
|
||||
chunks.append(text[:split_point])
|
||||
text = text[split_point:].lstrip()
|
||||
|
||||
return chunks
|
||||
|
||||
async def send_with_client(
|
||||
self, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
@@ -57,25 +95,29 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
if isinstance(i, Plain):
|
||||
if at_user_id and not at_flag:
|
||||
i.text = f"@{at_user_id} " + i.text
|
||||
i.text = f"@{at_user_id} {i.text}"
|
||||
at_flag = True
|
||||
text = i.text
|
||||
try:
|
||||
text = telegramify_markdown.markdownify(
|
||||
i.text, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||
)
|
||||
return
|
||||
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||
chunks = self._split_message(i.text)
|
||||
for chunk in chunks:
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text, parse_mode="MarkdownV2", **payload
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead."
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -126,7 +168,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -143,17 +186,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
|
||||
# Plain
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
else:
|
||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
@@ -172,6 +205,18 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
delta = ""
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
|
||||
try:
|
||||
if delta and current_content != delta:
|
||||
|
||||
@@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class QueueListener:
|
||||
@@ -40,7 +41,8 @@ class WebChatAdapter(Platform):
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id")
|
||||
|
||||
@@ -6,8 +6,9 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
imgs_dir = "data/webchat/imgs"
|
||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
@@ -0,0 +1,916 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import websockets
|
||||
from astrbot import logger
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api.platform import Platform, PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .wechatpadpro_message_event import WeChatPadProMessageEvent
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||
class WeChatPadProAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wechatpadpro",
|
||||
description="WeChatPadPro 消息平台适配器",
|
||||
id=self.config.get("id", "wechatpadpro"),
|
||||
)
|
||||
|
||||
# 保存配置信息
|
||||
self.admin_key = self.config.get("admin_key")
|
||||
self.host = self.config.get("host")
|
||||
self.port = self.config.get("port")
|
||||
self.active_mesasge_poll: bool = self.config.get(
|
||||
"wpp_active_message_poll", False
|
||||
)
|
||||
self.active_message_poll_interval: int = self.config.get(
|
||||
"wpp_active_message_poll_interval", 5
|
||||
)
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(), "wechatpadpro_credentials.json"
|
||||
) # 持久化文件路径
|
||||
self.ws_handle_task = None
|
||||
|
||||
# 添加图片消息缓存,用于引用消息处理
|
||||
self.cached_images = {}
|
||||
"""缓存图片消息。key是NewMsgId (对应引用消息的svrid),value是图片的base64数据"""
|
||||
# 设置缓存大小限制,避免内存占用过大
|
||||
self.max_image_cache = 50
|
||||
|
||||
# 添加文本消息缓存,用于引用消息处理
|
||||
self.cached_texts = {}
|
||||
"""缓存文本消息。key是NewMsgId (对应引用消息的svrid),value是消息文本内容"""
|
||||
# 设置文本缓存大小限制
|
||||
self.max_text_cache = 100
|
||||
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
启动平台适配器的运行实例。
|
||||
"""
|
||||
logger.info("WeChatPadPro 适配器正在启动...")
|
||||
|
||||
if loaded_credentials := self.load_credentials():
|
||||
self.auth_key = loaded_credentials.get("auth_key")
|
||||
self.wxid = loaded_credentials.get("wxid")
|
||||
|
||||
isLoginIn = await self.check_online_status()
|
||||
|
||||
# 检查在线状态
|
||||
if self.auth_key and isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
|
||||
# 如果在线,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
else:
|
||||
# 1. 生成授权码
|
||||
if not self.auth_key:
|
||||
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
|
||||
await self.generate_auth_key()
|
||||
|
||||
# 2. 获取登录二维码
|
||||
if not isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
|
||||
qr_code_url = await self.get_login_qr_code()
|
||||
|
||||
if qr_code_url:
|
||||
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
|
||||
else:
|
||||
logger.error("无法获取登录二维码。")
|
||||
return
|
||||
|
||||
# 3. 检测扫码状态
|
||||
login_successful = await self.check_login_status()
|
||||
|
||||
if login_successful:
|
||||
logger.info("登录成功,WeChatPadPro适配器已连接。")
|
||||
else:
|
||||
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
|
||||
await self.terminate()
|
||||
return
|
||||
|
||||
# 登录成功后,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
|
||||
self._shutdown_event = asyncio.Event()
|
||||
await self._shutdown_event.wait()
|
||||
logger.info("WeChatPadPro 适配器已停止。")
|
||||
|
||||
def load_credentials(self):
|
||||
"""
|
||||
从文件中加载 auth_key 和 wxid。
|
||||
"""
|
||||
if os.path.exists(self.credentials_file):
|
||||
try:
|
||||
with open(self.credentials_file, "r") as f:
|
||||
credentials = json.load(f)
|
||||
logger.info("成功加载 WeChatPadPro 凭据。")
|
||||
return credentials
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WeChatPadPro 凭据失败: {e}")
|
||||
return None
|
||||
|
||||
def save_credentials(self):
|
||||
"""
|
||||
将 auth_key 和 wxid 保存到文件。
|
||||
"""
|
||||
credentials = {
|
||||
"auth_key": self.auth_key,
|
||||
"wxid": self.wxid,
|
||||
}
|
||||
try:
|
||||
# 确保数据目录存在
|
||||
data_dir = os.path.dirname(self.credentials_file)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
with open(self.credentials_file, "w") as f:
|
||||
json.dump(credentials, f)
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
|
||||
|
||||
async def check_online_status(self):
|
||||
"""
|
||||
检查 WeChatPadPro 设备是否在线。
|
||||
"""
|
||||
if not self.auth_key:
|
||||
return False
|
||||
url = f"{self.base_url}/login/GetLoginStatus"
|
||||
params = {"key": self.auth_key}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response_data = await response.json()
|
||||
# 根据提供的在线接口返回示例,成功状态码是 200,loginState 为 1 表示在线
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
login_state = response_data.get("Data", {}).get("loginState")
|
||||
if login_state == 1:
|
||||
logger.info("WeChatPadPro 设备当前在线。")
|
||||
return True
|
||||
# login_state == 3 为离线状态
|
||||
elif login_state == 3:
|
||||
logger.info("WeChatPadPro 设备不在线。")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"未知的在线状态: {response_data}")
|
||||
return False
|
||||
# Code == 300 为微信退出状态。
|
||||
elif response.status == 200 and response_data.get("Code") == 300:
|
||||
logger.info("WeChatPadPro 设备已退出。")
|
||||
return False
|
||||
elif response.status == 200 and response_data.get("Code") == -2:
|
||||
# 该链接不存在
|
||||
self.auth_key = None
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检查在线状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
return False
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态时发生错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def generate_auth_key(self):
|
||||
"""
|
||||
生成授权码。
|
||||
"""
|
||||
url = f"{self.base_url}/admin/GenAuthKey1"
|
||||
params = {"key": self.admin_key}
|
||||
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
# 修正成功判断条件和授权码提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 授权码在 Data 字段的列表中
|
||||
if (
|
||||
response_data.get("Data")
|
||||
and isinstance(response_data["Data"], list)
|
||||
and len(response_data["Data"]) > 0
|
||||
):
|
||||
self.auth_key = response_data["Data"][0]
|
||||
logger.info(f"成功获取授权码 {self.auth_key[:8]}...")
|
||||
else:
|
||||
logger.error(
|
||||
f"生成授权码成功但未找到授权码: {response_data}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"生成授权码失败: {response.status}, {response_data}"
|
||||
)
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"生成授权码时发生错误: {e}")
|
||||
|
||||
async def get_login_qr_code(self):
|
||||
"""
|
||||
获取登录二维码地址。
|
||||
"""
|
||||
url = f"{self.base_url}/login/GetLoginQrCodeNew"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {} # 根据文档,这个接口的 body 可以为空
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 二维码地址在 Data.QrCodeUrl 字段中
|
||||
if response_data.get("Data") and response_data["Data"].get(
|
||||
"QrCodeUrl"
|
||||
):
|
||||
return response_data["Data"]["QrCodeUrl"]
|
||||
else:
|
||||
logger.error(
|
||||
f"获取登录二维码成功但未找到二维码地址: {response_data}"
|
||||
)
|
||||
return None
|
||||
elif "该 key 无效" in response_data.get("Text"):
|
||||
logger.error(
|
||||
"授权码无效,已经清除。请重新启动 AstrBot 或者本消息适配器。原因也可能是 WeChatPadPro 的 MySQL 服务没有启动成功,请检查 WeChatPadPro 服务的日志。"
|
||||
)
|
||||
self.auth_key = None
|
||||
self.save_credentials()
|
||||
return None
|
||||
else:
|
||||
logger.error(
|
||||
f"获取登录二维码失败: {response.status}, {response_data}"
|
||||
)
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取登录二维码时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def check_login_status(self):
|
||||
"""
|
||||
循环检测扫码状态。
|
||||
尝试 6 次后跳出循环,添加倒计时。
|
||||
返回 True 如果登录成功,否则返回 False。
|
||||
"""
|
||||
url = f"{self.base_url}/login/CheckLoginStatus"
|
||||
params = {"key": self.auth_key}
|
||||
|
||||
attempts = 0 # 初始化尝试次数
|
||||
max_attempts = 36 # 最大尝试次数
|
||||
countdown = 180 # 倒计时时长
|
||||
logger.info(f"请在 {countdown} 秒内扫码登录。")
|
||||
while attempts < max_attempts:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response_data = await response.json()
|
||||
# 成功判断条件和数据提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
if (
|
||||
response_data.get("Data")
|
||||
and response_data["Data"].get("state") is not None
|
||||
):
|
||||
status = response_data["Data"]["state"]
|
||||
logger.info(
|
||||
f"第 {attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}秒"
|
||||
)
|
||||
if status == 2: # 状态 2 表示登录成功
|
||||
self.wxid = response_data["Data"].get("wxid")
|
||||
self.wxnewpass = response_data["Data"].get(
|
||||
"wxnewpass"
|
||||
)
|
||||
logger.info(
|
||||
f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}"
|
||||
)
|
||||
self.save_credentials() # 登录成功后保存凭据
|
||||
return True
|
||||
elif status == -2: # 二维码过期
|
||||
logger.error("二维码已过期,请重新获取。")
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检测登录状态成功但未找到登录状态: {response_data}"
|
||||
)
|
||||
elif response_data.get("Code") == 300:
|
||||
# "不存在状态"
|
||||
pass
|
||||
else:
|
||||
logger.info(
|
||||
f"检测登录状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
await asyncio.sleep(5)
|
||||
attempts += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"检测登录状态时发生错误: {e}")
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
attempts += 1
|
||||
await asyncio.sleep(5) # 每隔5秒检测一次
|
||||
logger.warning("登录检测超过最大尝试次数,退出检测。")
|
||||
return False
|
||||
|
||||
async def connect_websocket(self):
|
||||
"""
|
||||
建立 WebSocket 连接并处理接收到的消息。
|
||||
"""
|
||||
os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}"
|
||||
ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}"
|
||||
logger.info(
|
||||
f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
logger.debug("WebSocket 连接成功。")
|
||||
# 设置空闲超时重连
|
||||
wait_time = (
|
||||
self.active_message_poll_interval
|
||||
if self.active_mesasge_poll
|
||||
else 120
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(), timeout=wait_time
|
||||
)
|
||||
# logger.debug(message) # 不显示原始消息内容
|
||||
asyncio.create_task(self.handle_websocket_message(message))
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug(f"WebSocket 连接空闲超过 {wait_time} s")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
logger.info("WebSocket 连接正常关闭。")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
"""
|
||||
处理从 WebSocket 接收到的消息。
|
||||
"""
|
||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||
try:
|
||||
message_data = json.loads(message)
|
||||
if (
|
||||
message_data.get("msg_id") is not None
|
||||
and message_data.get("from_user_name") is not None
|
||||
):
|
||||
abm = await self.convert_message(message_data)
|
||||
if abm:
|
||||
# 创建 WeChatPadProMessageEvent 实例
|
||||
message_event = WeChatPadProMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
# 传递适配器实例,以便在事件中调用 send 方法
|
||||
adapter=self,
|
||||
)
|
||||
# 提交事件到事件队列
|
||||
self.commit_event(message_event)
|
||||
else:
|
||||
logger.warning(f"收到未知结构的 WebSocket 消息: {message_data}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析 WebSocket 消息为 JSON: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
|
||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||
"""
|
||||
将 WeChatPadPro 原始消息转换为 AstrBotMessage。
|
||||
"""
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = raw_message
|
||||
abm.message_id = str(raw_message.get("msg_id"))
|
||||
abm.timestamp = raw_message.get("create_time")
|
||||
abm.self_id = self.wxid
|
||||
|
||||
if int(time.time()) - abm.timestamp > 180:
|
||||
logger.warning(
|
||||
f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}。"
|
||||
)
|
||||
return None
|
||||
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
content = raw_message.get("content", {}).get("str", "")
|
||||
push_content = raw_message.get("push_content", "")
|
||||
msg_type = raw_message.get("msg_type")
|
||||
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
|
||||
# 如果是机器人自己发送的消息、回显消息或系统消息,忽略
|
||||
if from_user_name == self.wxid:
|
||||
logger.info("忽略来自自己的消息。")
|
||||
return None
|
||||
|
||||
if from_user_name in ["weixin", "newsapp", "newsapp_wechat"]:
|
||||
logger.info("忽略来自微信团队的消息。")
|
||||
return None
|
||||
|
||||
# 先判断群聊/私聊并设置基本属性
|
||||
if await self._process_chat_type(
|
||||
abm, raw_message, from_user_name, to_user_name, content, push_content
|
||||
):
|
||||
# 再根据消息类型处理消息内容
|
||||
await self._process_message_content(abm, raw_message, msg_type, content)
|
||||
|
||||
return abm
|
||||
return None
|
||||
|
||||
async def _process_chat_type(
|
||||
self,
|
||||
abm: AstrBotMessage,
|
||||
raw_message: dict,
|
||||
from_user_name: str,
|
||||
to_user_name: str,
|
||||
content: str,
|
||||
push_content: str,
|
||||
):
|
||||
"""
|
||||
判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。
|
||||
"""
|
||||
if from_user_name == "weixin":
|
||||
return False
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = from_user_name
|
||||
|
||||
parts = content.split(":\n", 1)
|
||||
sender_wxid = parts[0] if len(parts) == 2 else ""
|
||||
abm.sender = MessageMember(user_id=sender_wxid, nickname="")
|
||||
|
||||
# 获取群聊发送者的nickname
|
||||
if sender_wxid:
|
||||
accurate_nickname = await self._get_group_member_nickname(
|
||||
abm.group_id, sender_wxid
|
||||
)
|
||||
if accurate_nickname:
|
||||
abm.sender.nickname = accurate_nickname
|
||||
|
||||
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
|
||||
if self.unique_session:
|
||||
abm.session_id = f"{from_user_name}#{abm.sender.user_id}"
|
||||
else:
|
||||
abm.session_id = from_user_name
|
||||
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if self.wxid in msg_source:
|
||||
at_me = True
|
||||
if "在群聊中@了你" in raw_message.get("push_content", ""):
|
||||
at_me = True
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=""))
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.group_id = ""
|
||||
nick_name = ""
|
||||
if push_content and " : " in push_content:
|
||||
nick_name = push_content.split(" : ")[0]
|
||||
abm.sender = MessageMember(user_id=from_user_name, nickname=nick_name)
|
||||
abm.session_id = from_user_name
|
||||
return True
|
||||
|
||||
async def _get_group_member_nickname(
|
||||
self, group_id: str, member_wxid: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
通过接口获取群成员的昵称。
|
||||
"""
|
||||
url = f"{self.base_url}/group/GetChatroomMemberDetail"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"ChatRoomName": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 从返回数据中查找对应成员的昵称
|
||||
member_list = (
|
||||
response_data.get("Data", {})
|
||||
.get("member_data", {})
|
||||
.get("chatroom_member_list", [])
|
||||
)
|
||||
for member in member_list:
|
||||
if member.get("user_name") == member_wxid:
|
||||
return member.get("nick_name")
|
||||
logger.warning(
|
||||
f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"获取群成员详情失败: {response.status}, {response_data}"
|
||||
)
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群成员详情时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _download_raw_image(
|
||||
self, from_user_name: str, to_user_name: str, msg_id: int
|
||||
):
|
||||
"""下载原始图片。"""
|
||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"CompressType": 0,
|
||||
"FromUserName": from_user_name,
|
||||
"MsgId": msg_id,
|
||||
"Section": {"DataLen": 61440, "StartPos": 0},
|
||||
"ToUserName": to_user_name,
|
||||
"TotalLen": 0,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
logger.error(f"下载图片失败: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def download_voice(
|
||||
self, to_user_name: str, new_msg_id: str, bufid: str, length: int
|
||||
):
|
||||
"""下载原始音频。"""
|
||||
url = f"{self.base_url}/message/GetMsgVoice"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"Bufid": bufid,
|
||||
"ToUserName": to_user_name,
|
||||
"NewMsgId": new_msg_id,
|
||||
"Length": length,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
logger.error(f"下载音频失败: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _process_message_content(
|
||||
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
|
||||
):
|
||||
"""
|
||||
根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。
|
||||
"""
|
||||
if msg_type == 1: # 文本消息
|
||||
abm.message_str = content
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
parts = content.split(":\n", 1)
|
||||
if len(parts) == 2:
|
||||
message_content = parts[1]
|
||||
abm.message_str = message_content
|
||||
|
||||
# 检查是否@了机器人,参考 gewechat 的实现方式
|
||||
# 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符(四分之一空格)
|
||||
at_me = False
|
||||
|
||||
# 检查 msg_source 中是否包含机器人的 wxid
|
||||
# wechatpadpro 的格式: <atuserlist>wxid</atuserlist>
|
||||
# gewechat 的格式: <atuserlist><![CDATA[wxid]]></atuserlist>
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if (
|
||||
f"<atuserlist>{abm.self_id}</atuserlist>" in msg_source
|
||||
or f"<atuserlist>{abm.self_id}," in msg_source
|
||||
or f",{abm.self_id}</atuserlist>" in msg_source
|
||||
):
|
||||
at_me = True
|
||||
|
||||
# 也检查 push_content 中是否有@提示
|
||||
push_content = raw_message.get("push_content", "")
|
||||
if "在群聊中@了你" in push_content:
|
||||
at_me = True
|
||||
|
||||
if at_me:
|
||||
# 被@了,在消息开头插入At组件(参考gewechat的做法)
|
||||
bot_nickname = await self._get_group_member_nickname(
|
||||
abm.group_id, abm.self_id
|
||||
)
|
||||
abm.message.insert(
|
||||
0, At(qq=abm.self_id, name=bot_nickname or abm.self_id)
|
||||
)
|
||||
|
||||
# 只有当消息内容不仅仅是@时才添加Plain组件
|
||||
if "\u2005" in message_content:
|
||||
# 检查@之后是否还有其他内容
|
||||
parts = message_content.split("\u2005")
|
||||
if len(parts) > 1 and any(
|
||||
part.strip() for part in parts[1:]
|
||||
):
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
# 检查是否只包含@机器人
|
||||
is_pure_at = False
|
||||
if (
|
||||
bot_nickname
|
||||
and message_content.strip() == f"@{bot_nickname}"
|
||||
):
|
||||
is_pure_at = True
|
||||
if not is_pure_at:
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
# 没有@机器人,作为普通文本处理
|
||||
abm.message.append(Plain(message_content))
|
||||
else:
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
else: # 私聊消息
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
|
||||
# 缓存文本消息,以便引用消息可以查找
|
||||
try:
|
||||
# 获取msg_id作为缓存的key
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id:
|
||||
# 限制缓存大小
|
||||
if (
|
||||
len(self.cached_texts) >= self.max_text_cache
|
||||
and self.cached_texts
|
||||
):
|
||||
# 删除最早的一条缓存
|
||||
oldest_key = next(iter(self.cached_texts))
|
||||
self.cached_texts.pop(oldest_key)
|
||||
|
||||
logger.debug(f"缓存文本消息,new_msg_id={new_msg_id}")
|
||||
self.cached_texts[str(new_msg_id)] = content
|
||||
except Exception as e:
|
||||
logger.error(f"缓存文本消息失败: {e}")
|
||||
elif msg_type == 3:
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
msg_id = raw_message.get("msg_id")
|
||||
image_resp = await self._download_raw_image(
|
||||
from_user_name, to_user_name, msg_id
|
||||
)
|
||||
image_bs64_data = (
|
||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||
)
|
||||
if image_bs64_data:
|
||||
abm.message.append(Image.fromBase64(image_bs64_data))
|
||||
# 缓存图片,以便引用消息可以查找
|
||||
try:
|
||||
# 获取msg_id作为缓存的key
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id:
|
||||
# 限制缓存大小
|
||||
if (
|
||||
len(self.cached_images) >= self.max_image_cache
|
||||
and self.cached_images
|
||||
):
|
||||
# 删除最早的一条缓存
|
||||
oldest_key = next(iter(self.cached_images))
|
||||
self.cached_images.pop(oldest_key)
|
||||
|
||||
logger.debug(f"缓存图片消息,new_msg_id={new_msg_id}")
|
||||
self.cached_images[str(new_msg_id)] = image_bs64_data
|
||||
except Exception as e:
|
||||
logger.error(f"缓存图片消息失败: {e}")
|
||||
elif msg_type == 47:
|
||||
# 视频消息 (注意:表情消息也是 47,需要区分)
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
emoji_message = data_parser.parse_emoji()
|
||||
if emoji_message is not None:
|
||||
abm.message.append(emoji_message)
|
||||
elif msg_type == 50:
|
||||
logger.warning("收到语音/视频消息,待实现。")
|
||||
elif msg_type == 34:
|
||||
# 语音消息
|
||||
bufid = 0
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
|
||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||
bufid = voicemsg.get("bufid") or "0"
|
||||
length = int(voicemsg.get("length") or 0)
|
||||
voice_resp = await self.download_voice(
|
||||
to_user_name=to_user_name,
|
||||
new_msg_id=new_msg_id,
|
||||
bufid=bufid,
|
||||
length=length,
|
||||
)
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_bs64_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
elif msg_type == 49:
|
||||
try:
|
||||
parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
cached_texts=self.cached_texts,
|
||||
cached_images=self.cached_images,
|
||||
raw_message=raw_message,
|
||||
downloader=self._download_raw_image,
|
||||
)
|
||||
components = await parser.parse_mutil_49()
|
||||
if components:
|
||||
abm.message.extend(components)
|
||||
abm.message_str = "\n".join(
|
||||
c.text for c in components if isinstance(c, Plain)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"msg_type 49 处理失败: {e}")
|
||||
abm.message.append(Plain("[XML 消息处理失败]"))
|
||||
abm.message_str = "[XML 消息处理失败]"
|
||||
else:
|
||||
logger.warning(f"收到未处理的消息类型: {msg_type}。")
|
||||
|
||||
async def terminate(self):
|
||||
"""
|
||||
终止一个平台的运行实例。
|
||||
"""
|
||||
logger.info("终止 WeChatPadPro 适配器。")
|
||||
try:
|
||||
if self.ws_handle_task:
|
||||
self.ws_handle_task.cancel()
|
||||
self._shutdown_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""
|
||||
得到一个平台的元数据。
|
||||
"""
|
||||
return self.metadata
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
dummy_message_obj = AstrBotMessage()
|
||||
dummy_message_obj.session_id = session.session_id
|
||||
# 根据 session_id 判断消息类型
|
||||
if "@chatroom" in session.session_id:
|
||||
dummy_message_obj.type = MessageType.GROUP_MESSAGE
|
||||
if "#" in session.session_id:
|
||||
dummy_message_obj.group_id = session.session_id.split("#")[0]
|
||||
else:
|
||||
dummy_message_obj.group_id = session.session_id
|
||||
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||
else:
|
||||
dummy_message_obj.type = MessageType.FRIEND_MESSAGE
|
||||
dummy_message_obj.group_id = ""
|
||||
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||
sending_event = WeChatPadProMessageEvent(
|
||||
message_str="",
|
||||
message_obj=dummy_message_obj,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
adapter=self,
|
||||
)
|
||||
# 调用实例方法 send
|
||||
await sending_event.send(message_chain)
|
||||
|
||||
async def get_contact_list(self):
|
||||
"""
|
||||
获取联系人列表。
|
||||
"""
|
||||
url = f"{self.base_url}/friend/GetContactList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = (
|
||||
result.get("Data", {})
|
||||
.get("ContactList", {})
|
||||
.get("contactUsernameList", [])
|
||||
)
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人列表时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def get_contact_details_list(
|
||||
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
获取联系人详情列表。
|
||||
"""
|
||||
if room_wx_id_list is None:
|
||||
room_wx_id_list = []
|
||||
if user_names is None:
|
||||
user_names = []
|
||||
url = f"{self.base_url}/friend/GetContactDetailsList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人详情列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = result.get("Data", {}).get("contactList", {})
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人详情列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人详情列表时发生错误: {e}")
|
||||
return None
|
||||
@@ -0,0 +1,161 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image as PILImage # 使用别名避免冲突
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import (
|
||||
Image,
|
||||
Plain,
|
||||
WechatEmoji,
|
||||
Record,
|
||||
) # Import Image
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||
|
||||
|
||||
class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
adapter: "WeChatPadProAdapter", # 传递适配器实例
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.message_obj = message_obj # Save the full message object
|
||||
self.adapter = adapter # Save the adapter instance
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for comp in message.chain:
|
||||
await asyncio.sleep(1)
|
||||
if isinstance(comp, Plain):
|
||||
await self._send_text(session, comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
await self._send_image(session, comp)
|
||||
elif isinstance(comp, WechatEmoji):
|
||||
await self._send_emoji(session, comp)
|
||||
elif isinstance(comp, Record):
|
||||
await self._send_voice(session, comp)
|
||||
await super().send(message)
|
||||
|
||||
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||
b64 = await comp.convert_to_base64()
|
||||
raw = self._validate_base64(b64)
|
||||
b64c = self._compress_image(raw)
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendImageNewMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_text(self, session: aiohttp.ClientSession, text: str):
|
||||
if (
|
||||
self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息
|
||||
and self.adapter.settings.get(
|
||||
"reply_with_mention", False
|
||||
) # 检查适配器设置是否启用 reply_with_mention
|
||||
and self.message_obj.sender # 确保有发送者信息
|
||||
and (
|
||||
self.message_obj.sender.user_id or self.message_obj.sender.nickname
|
||||
) # 确保发送者有 ID 或昵称
|
||||
):
|
||||
# 优先使用 nickname,如果没有则使用 user_id
|
||||
mention_text = (
|
||||
self.message_obj.sender.nickname or self.message_obj.sender.user_id
|
||||
)
|
||||
message_text = f"@{mention_text} {text}"
|
||||
# logger.info(f"已添加 @ 信息: {message_text}")
|
||||
else:
|
||||
message_text = text
|
||||
if self.get_group_id() and "#" in self.session_id:
|
||||
session_id = self.session_id.split("#")[0]
|
||||
else:
|
||||
session_id = self.session_id
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{
|
||||
"MsgType": 1,
|
||||
"TextContent": message_text,
|
||||
"ToUserName": session_id,
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendTextMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji):
|
||||
payload = {
|
||||
"EmojiList": [
|
||||
{
|
||||
"EmojiMd5": comp.md5,
|
||||
"EmojiSize": comp.md5_len,
|
||||
"ToUserName": self.session_id,
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendEmojiMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 默认已经存在 data/temp 中
|
||||
b64, duration = await wav_to_tencent_silk_base64(record_path)
|
||||
payload = {
|
||||
"ToUserName": self.session_id,
|
||||
"VoiceData": b64,
|
||||
"VoiceFormat": 4,
|
||||
"VoiceSecond": duration,
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendVoice"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
@staticmethod
|
||||
def _validate_base64(b64: str) -> bytes:
|
||||
return base64.b64decode(b64, validate=True)
|
||||
|
||||
@staticmethod
|
||||
def _compress_image(data: bytes) -> str:
|
||||
img = PILImage.open(io.BytesIO(data))
|
||||
buf = io.BytesIO()
|
||||
if img.format == "JPEG":
|
||||
img.save(buf, "JPEG", quality=80)
|
||||
else:
|
||||
if img.mode in ("RGBA", "P"):
|
||||
img = img.convert("RGB")
|
||||
img.save(buf, "JPEG", quality=80)
|
||||
# logger.info("图片处理完成!!!")
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
async def _post(self, session, url, payload):
|
||||
params = {"key": self.adapter.auth_key}
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as resp:
|
||||
data = await resp.json()
|
||||
if resp.status != 200 or data.get("Code") != 200:
|
||||
logger.error(f"{url} failed: {resp.status} {data}")
|
||||
except Exception as e:
|
||||
logger.error(f"{url} error: {e}")
|
||||
|
||||
|
||||
# TODO: 添加对其他消息组件类型的处理 (Record, Video, At等)
|
||||
# elif isinstance(component, Record):
|
||||
# pass
|
||||
# elif isinstance(component, Video):
|
||||
# pass
|
||||
# elif isinstance(component, At):
|
||||
# pass
|
||||
# ...
|
||||
160
astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
Normal file
160
astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Plain,
|
||||
Image,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
def __init__(
|
||||
self,
|
||||
content: str,
|
||||
is_private_chat: bool = False,
|
||||
cached_texts=None,
|
||||
cached_images=None,
|
||||
raw_message: dict = None,
|
||||
downloader=None,
|
||||
):
|
||||
self._xml = None
|
||||
self.content = content
|
||||
self.is_private_chat = is_private_chat
|
||||
self.cached_texts = cached_texts or {}
|
||||
self.cached_images = cached_images or {}
|
||||
self.downloader = downloader
|
||||
|
||||
raw_message = raw_message or {}
|
||||
self.from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
self.to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
self.msg_id = raw_message.get("msg_id", "")
|
||||
|
||||
def _format_to_xml(self):
|
||||
if self._xml:
|
||||
return self._xml
|
||||
|
||||
try:
|
||||
msg_str = self.content
|
||||
if not self.is_private_chat:
|
||||
parts = self.content.split(":\n", 1)
|
||||
msg_str = parts[1] if len(parts) == 2 else self.content
|
||||
|
||||
self._xml = eT.fromstring(msg_str)
|
||||
return self._xml
|
||||
except Exception as e:
|
||||
logger.error(f"[XML解析失败] {e}")
|
||||
raise
|
||||
|
||||
async def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
"""
|
||||
处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57)
|
||||
"""
|
||||
try:
|
||||
appmsg_type = self._format_to_xml().findtext(".//appmsg/type")
|
||||
if appmsg_type == "57":
|
||||
return await self.parse_reply()
|
||||
except Exception as e:
|
||||
logger.warning(f"[parse_mutil_49] 解析失败: {e}")
|
||||
return None
|
||||
|
||||
async def parse_reply(self) -> list[BaseMessageComponent]:
|
||||
"""
|
||||
处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49)
|
||||
"""
|
||||
components = []
|
||||
|
||||
try:
|
||||
appmsg = self._format_to_xml().find("appmsg")
|
||||
if appmsg is None:
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
refermsg = appmsg.find("refermsg")
|
||||
if refermsg is None:
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
quote_type = int(refermsg.findtext("type", "0"))
|
||||
nickname = refermsg.findtext("displayname", "未知发送者")
|
||||
quote_content = refermsg.findtext("content", "")
|
||||
svrid = refermsg.findtext("svrid")
|
||||
|
||||
match quote_type:
|
||||
case 1: # 文本引用
|
||||
quoted_text = self.cached_texts.get(str(svrid), quote_content)
|
||||
components.append(Plain(f"[引用] {nickname}: {quoted_text}"))
|
||||
|
||||
case 3: # 图片引用
|
||||
quoted_image_b64 = self.cached_images.get(str(svrid))
|
||||
if not quoted_image_b64:
|
||||
try:
|
||||
quote_xml = eT.fromstring(quote_content)
|
||||
img = quote_xml.find("img")
|
||||
cdn_url = (
|
||||
img.get("cdnbigimgurl") or img.get("cdnmidimgurl")
|
||||
if img is not None
|
||||
else None
|
||||
)
|
||||
if cdn_url and self.downloader:
|
||||
image_resp = await self.downloader(
|
||||
self.from_user_name, self.to_user_name, self.msg_id
|
||||
)
|
||||
quoted_image_b64 = (
|
||||
image_resp.get("Data", {})
|
||||
.get("Data", {})
|
||||
.get("Buffer")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}")
|
||||
|
||||
if quoted_image_b64:
|
||||
components.extend(
|
||||
[
|
||||
Image.fromBase64(quoted_image_b64),
|
||||
Plain(f"[引用] {nickname}: [引用的图片]"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
components.append(
|
||||
Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]")
|
||||
)
|
||||
|
||||
case 49: # 嵌套引用
|
||||
try:
|
||||
nested_root = eT.fromstring(quote_content)
|
||||
nested_title = nested_root.findtext(".//appmsg/title", "")
|
||||
components.append(Plain(f"[引用] {nickname}: {nested_title}"))
|
||||
except Exception as e:
|
||||
logger.warning(f"[嵌套引用解析失败] err={e}")
|
||||
components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]"))
|
||||
|
||||
case _: # 其他未识别类型
|
||||
logger.info(f"[未知引用类型] quote_type={quote_type}")
|
||||
components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]"))
|
||||
|
||||
# 主消息标题
|
||||
title = appmsg.findtext("title", "")
|
||||
if title:
|
||||
components.append(Plain(title))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[parse_reply] 总体解析失败: {e}")
|
||||
return [Plain("[引用消息解析失败]")]
|
||||
|
||||
return components
|
||||
|
||||
def parse_emoji(self) -> Emoji | None:
|
||||
"""
|
||||
处理 msg_type == 47 的表情消息(emoji)
|
||||
"""
|
||||
try:
|
||||
emoji_element = self._format_to_xml().find(".//emoji")
|
||||
if emoji_element is not None:
|
||||
return Emoji(
|
||||
md5=emoji_element.get("md5"),
|
||||
md5_len=emoji_element.get("len"),
|
||||
cdnurl=emoji_element.get("cdnurl"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[parse_emoji] 解析失败: {e}")
|
||||
|
||||
return None
|
||||
@@ -1,31 +1,31 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import quart
|
||||
import aiohttp
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
from wechatpy.enterprise import WeChatClient, parse_message
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise.messages import ImageMessage, TextMessage, VoiceMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.messages import BaseMessage
|
||||
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Image, Plain, Record
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core import logger
|
||||
from requests import Response
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.messages import BaseMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.enterprise import parse_message
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
|
||||
from .wecom_kf import WeChatKF
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
@@ -146,7 +146,7 @@ class WecomPlatformAdapter(Platform):
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
@@ -257,14 +257,15 @@ class WecomPlatformAdapter(Platform):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, msg.media_id
|
||||
)
|
||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
@@ -296,11 +297,13 @@ class WecomPlatformAdapter(Platform):
|
||||
external_userid = msg.get("external_userid", None)
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
abm.self_id = msg["open_kfid"]
|
||||
abm.sender = MessageMember(external_userid, external_userid)
|
||||
abm.session_id = external_userid
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
|
||||
abm.message_str = ""
|
||||
if msgtype == "text":
|
||||
text = msg.get("text", {}).get("content", "").strip()
|
||||
abm.message = [Plain(text=text)]
|
||||
@@ -314,7 +317,29 @@ class WecomPlatformAdapter(Platform):
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
abm.message = [Image(file=path, url=path)]
|
||||
abm.message_str = "[图片]"
|
||||
elif msgtype == "voice":
|
||||
media_id = msg.get("voice", {}).get("media_id", "")
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, media_id
|
||||
)
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav")
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。")
|
||||
path_wav = path
|
||||
return
|
||||
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
else:
|
||||
logger.warning(f"未实现的微信客服消息事件: {msg}")
|
||||
return
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -7,6 +8,7 @@ from wechatpy.enterprise import WeChatClient
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import pydub
|
||||
@@ -118,6 +120,30 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传语音失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.info(f"微信客服上传语音返回: {response}")
|
||||
kf_message_api.send_voice(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
else:
|
||||
@@ -152,7 +178,8 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ from requests import Response
|
||||
from wechatpy.utils import check_signature
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy import parse_message
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
@@ -87,7 +87,11 @@ class WecomServer:
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
await self.callback(msg)
|
||||
result_xml = await self.callback(msg)
|
||||
if not result_xml:
|
||||
return "success"
|
||||
if isinstance(result_xml, str):
|
||||
return result_xml
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -117,6 +121,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.api_base_url = platform_config.get(
|
||||
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
|
||||
)
|
||||
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||
@@ -136,9 +141,31 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
async def callback(msg):
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
self.wexin_event_workers: dict[str, asyncio.Future] = {}
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
try:
|
||||
await self.convert_message(msg)
|
||||
if self.active_send_mode:
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息时出现异常: {e}")
|
||||
|
||||
@@ -161,7 +188,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(self, msg) -> AstrBotMessage | None:
|
||||
async def convert_message(
|
||||
self, msg, future: asyncio.Future = None
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
@@ -175,7 +204,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
@@ -189,7 +217,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
@@ -207,7 +234,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
logger.error(f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。")
|
||||
logger.error(
|
||||
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
|
||||
)
|
||||
path_wav = path
|
||||
return
|
||||
|
||||
@@ -222,11 +251,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
future.set_result(None)
|
||||
return
|
||||
|
||||
# 很不优雅 :(
|
||||
abm.raw_message = {
|
||||
"message": msg,
|
||||
"future": future,
|
||||
"active_send_mode": self.active_send_mode,
|
||||
}
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import TextReply, ImageReply, VoiceReply
|
||||
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
@@ -82,12 +84,23 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||
if active_send_mode:
|
||||
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
@@ -102,10 +115,22 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
@@ -124,10 +149,23 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
)
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -155,7 +156,9 @@ class ProviderRequest:
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
|
||||
@@ -4,6 +4,7 @@ import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
@@ -12,12 +13,21 @@ from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
@@ -94,7 +104,10 @@ class MCPClient:
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
@@ -106,15 +119,41 @@ class MCPClient:
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
|
||||
if "url" in cfg:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(url=cfg["url"])
|
||||
streams = await self._streams_context.__aenter__()
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
# self.session = await self._session_context.__aenter__()
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
@@ -238,8 +277,7 @@ class FuncCall:
|
||||
}
|
||||
```
|
||||
"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||
if not os.path.exists(mcp_json_file):
|
||||
|
||||
@@ -18,13 +18,6 @@ class ProviderManager:
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
self.selected_provider_id = sp.get("curr_provider")
|
||||
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
self.provider_enabled = self.provider_settings.get("enable", False)
|
||||
self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
@@ -98,15 +91,18 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
self.curr_provider_inst: Provider = None
|
||||
"""当前使用的 Provider 实例"""
|
||||
"""默认的 Provider 实例"""
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
"""当前使用的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
"""当前使用的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
@@ -115,18 +111,57 @@ class ProviderManager:
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo and self.provider_settings["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
# 设置默认提供商
|
||||
self.curr_provider_inst = self.inst_map.get(
|
||||
self.provider_settings.get("default_provider_id")
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
if self.stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
self.curr_stt_provider_inst = self.inst_map.get(
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
self.curr_tts_provider_inst = self.inst_map.get(
|
||||
self.provider_tts_settings.get("provider_id")
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
@@ -190,6 +225,10 @@ class ProviderManager:
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
@@ -202,6 +241,26 @@ class ProviderManager:
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -234,14 +293,14 @@ class ProviderManager:
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_stt_provider_id == provider_config["id"]
|
||||
and self.stt_enabled
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_enabled:
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
@@ -254,15 +313,12 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_tts_provider_id == provider_config["id"]
|
||||
and self.tts_enabled
|
||||
):
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_enabled:
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
@@ -280,16 +336,24 @@ class ProviderManager:
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_provider_id == provider_config["id"]
|
||||
and self.provider_enabled
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_enabled:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -310,39 +374,24 @@ class ProviderManager:
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif (
|
||||
self.curr_provider_inst is None
|
||||
and len(self.provider_insts) > 0
|
||||
and self.provider_enabled
|
||||
):
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None
|
||||
and len(self.stt_provider_insts) > 0
|
||||
and self.stt_enabled
|
||||
):
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None
|
||||
and len(self.tts_provider_insts) > 0
|
||||
and self.tts_enabled
|
||||
):
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
|
||||
@@ -179,3 +179,25 @@ class TTSProvider(AbstractProvider):
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
@@ -104,11 +104,13 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
if not prompt:
|
||||
prompt = "<image>"
|
||||
|
||||
|
||||
210
astrbot/core/provider/sources/azure_tts_source.py
Normal file
210
astrbot/core/provider/sources/azure_tts_source.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
import hashlib
|
||||
import random
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from httpx import AsyncClient, Timeout
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
TEMP_DIR = Path("data/temp/azure_tts")
|
||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
class OTTSProvider:
|
||||
def __init__(self, config: Dict):
|
||||
self.skey = config["OTTS_SKEY"]
|
||||
self.api_url = config["OTTS_URL"]
|
||||
self.auth_time_url = config["OTTS_AUTH_TIME"]
|
||||
self.time_offset = 0
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.client = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(timeout=self.timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
async def _sync_time(self):
|
||||
try:
|
||||
response = await self.client.get(self.auth_time_url)
|
||||
response.raise_for_status()
|
||||
server_time = int(response.json()["timestamp"])
|
||||
local_time = int(time.time())
|
||||
self.time_offset = server_time - local_time
|
||||
self.last_sync_time = local_time
|
||||
except Exception as e:
|
||||
if time.time() - self.last_sync_time > 3600:
|
||||
raise RuntimeError("时间同步失败") from e
|
||||
|
||||
async def _generate_signature(self) -> str:
|
||||
await self._sync_time()
|
||||
timestamp = int(time.time()) + self.time_offset
|
||||
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
|
||||
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
||||
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
||||
|
||||
async def get_audio(self, text: str, voice_params: Dict) -> str:
|
||||
file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav"
|
||||
signature = await self._generate_signature()
|
||||
for attempt in range(self.retry_count):
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.api_url}?sign={signature}",
|
||||
data={
|
||||
"text": text,
|
||||
"voice": voice_params["voice"],
|
||||
"style": voice_params["style"],
|
||||
"role": voice_params["role"],
|
||||
"rate": voice_params["rate"],
|
||||
"volume": voice_params["volume"]
|
||||
},
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"UAK": "AstrBot/AzureTTS"
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
async for chunk in response.aiter_bytes(4096):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
except Exception as e:
|
||||
if attempt == self.retry_count - 1:
|
||||
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
|
||||
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||
raise ValueError("无效的Azure订阅密钥")
|
||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
self.client = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
self.voice_params = {
|
||||
"voice": provider_config.get("azure_tts_voice", "zh-CN-YunxiaNeural"),
|
||||
"style": provider_config.get("azure_tts_style", "cheerful"),
|
||||
"role": provider_config.get("azure_tts_role", "Boy"),
|
||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||
"volume": provider_config.get("azure_tts_volume", "100")
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
|
||||
})
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
||||
response = await self.client.post(
|
||||
token_url,
|
||||
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
self.token = response.text
|
||||
self.token_expire = time.time() + 540
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
if not self.token or time.time() > self.token_expire:
|
||||
await self._refresh_token()
|
||||
file_path = TEMP_DIR / f"azure-{uuid.uuid4()}.wav"
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis'
|
||||
xmlns:mstts='http://www.w3.org/2001/mstts' xml:lang='zh-CN'>
|
||||
<voice name='{escape(self.voice_params["voice"])}'>
|
||||
<mstts:express-as style='{escape(self.voice_params["style"])}'
|
||||
role='{escape(self.voice_params["role"])}'>
|
||||
<prosody rate='{escape(self.voice_params["rate"])}'
|
||||
volume='{escape(self.voice_params["volume"])}'>
|
||||
{escape(text)}
|
||||
</prosody>
|
||||
</mstts:express-as>
|
||||
</voice>
|
||||
</speak>"""
|
||||
response = await self.client.post(
|
||||
self.endpoint,
|
||||
content=ssml,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.token}",
|
||||
"User-Agent": f"AstrBot/{VERSION}"
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
for chunk in response.iter_bytes(4096):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
|
||||
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
|
||||
class AzureTTSProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||
self.provider = self._parse_provider(key_value, provider_config)
|
||||
|
||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
||||
if key_value.lower().startswith("other["):
|
||||
try:
|
||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError("无效的other[...]格式,应形如 other[{...}]")
|
||||
json_str = match.group(1).strip()
|
||||
otts_config = json.loads(json_str)
|
||||
required = {"OTTS_SKEY", "OTTS_URL", "OTTS_AUTH_TIME"}
|
||||
if missing := required - otts_config.keys():
|
||||
raise ValueError(f"缺少OTTS参数: {', '.join(missing)}")
|
||||
return OTTSProvider(otts_config)
|
||||
except json.JSONDecodeError as e:
|
||||
error_msg = (
|
||||
f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
|
||||
f"错误详情: {e.msg}\n"
|
||||
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
|
||||
)
|
||||
raise ValueError(error_msg) from e
|
||||
except KeyError as e:
|
||||
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
|
||||
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
|
||||
return AzureNativeProvider(config, self.provider_settings)
|
||||
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
if isinstance(self.provider, OTTSProvider):
|
||||
async with self.provider as provider:
|
||||
return await provider.get_audio(
|
||||
text,
|
||||
{
|
||||
"voice": self.provider_config.get("azure_tts_voice"),
|
||||
"style": self.provider_config.get("azure_tts_style"),
|
||||
"role": self.provider_config.get("azure_tts_role"),
|
||||
"rate": self.provider_config.get("azure_tts_rate"),
|
||||
"volume": self.provider_config.get("azure_tts_volume")
|
||||
}
|
||||
)
|
||||
else:
|
||||
async with self.provider as provider:
|
||||
return await provider.get_audio(text)
|
||||
@@ -74,6 +74,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
@@ -5,6 +6,7 @@ from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -24,7 +26,8 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
voice=self.voice,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
@@ -60,13 +61,16 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
result = ""
|
||||
session_id = session_id or kwargs.get("user") # 1734
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
@@ -227,7 +231,8 @@ class ProviderDify(Provider):
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
path = f"data/temp/{item['filename']}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
"""
|
||||
edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库
|
||||
@@ -40,9 +41,9 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
@@ -6,6 +7,7 @@ from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
@@ -87,7 +89,8 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
|
||||
63
astrbot/core/provider/sources/gemini_embedding_source.py
Normal file
63
astrbot/core/provider/sources/gemini_embedding_source.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"gemini_embedding",
|
||||
"Google Gemini Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
api_key: str = provider_config.get("embedding_api_key")
|
||||
api_base: str = provider_config.get("embedding_api_base", None)
|
||||
timeout: int = int(provider_config.get("timeout", 20))
|
||||
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
if api_base:
|
||||
if api_base.endswith("/"):
|
||||
api_base = api_base[:-1]
|
||||
http_options.base_url = api_base
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
|
||||
self.model = provider_config.get(
|
||||
"embedding_model", "gemini-embedding-exp-03-07"
|
||||
)
|
||||
self.dimension = provider_config.get("embedding_dimensions", 768)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model, contents=text
|
||||
)
|
||||
return result.embeddings[0].values
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model, contents=texts
|
||||
)
|
||||
return [embedding.values for embedding in result.embeddings]
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -141,24 +141,66 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
|
||||
tool_list = None
|
||||
tool_list = []
|
||||
model_name = self.get_model()
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
|
||||
if native_coderunner:
|
||||
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
|
||||
if native_search:
|
||||
logger.warning("已启用代码执行工具,搜索工具将被忽略")
|
||||
if tools:
|
||||
logger.warning("已启用代码执行工具,函数工具将被忽略")
|
||||
elif native_search:
|
||||
tool_list = [types.Tool(google_search=types.GoogleSearch())]
|
||||
if tools:
|
||||
logger.warning("已启用搜索工具,函数工具将被忽略")
|
||||
if "gemini-2.5" in model_name:
|
||||
if native_coderunner:
|
||||
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
|
||||
if native_search:
|
||||
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
|
||||
if url_context:
|
||||
logger.warning(
|
||||
"代码执行工具与URL上下文工具互斥,已忽略URL上下文工具"
|
||||
)
|
||||
else:
|
||||
if native_search:
|
||||
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
|
||||
if url_context:
|
||||
if hasattr(types, "UrlContext"):
|
||||
tool_list.append(types.Tool(url_context=types.UrlContext()))
|
||||
else:
|
||||
logger.warning(
|
||||
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
|
||||
)
|
||||
|
||||
elif "gemini-2.0-lite" in model_name:
|
||||
if native_coderunner or native_search or url_context:
|
||||
logger.warning(
|
||||
"gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文,将忽略这些设置"
|
||||
)
|
||||
tool_list = None
|
||||
|
||||
else:
|
||||
if native_coderunner:
|
||||
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
|
||||
if native_search:
|
||||
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
|
||||
elif native_search:
|
||||
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
|
||||
if url_context and not native_coderunner:
|
||||
if hasattr(types, "UrlContext"):
|
||||
tool_list.append(types.Tool(url_context=types.UrlContext()))
|
||||
else:
|
||||
logger.warning(
|
||||
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
|
||||
)
|
||||
|
||||
if not tool_list:
|
||||
tool_list = None
|
||||
|
||||
if tools and tool_list:
|
||||
logger.warning("已启用原生工具,函数工具将被忽略")
|
||||
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
@@ -189,6 +231,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
@@ -290,19 +333,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
||||
|
||||
if finish_reason in {
|
||||
types.FinishReason.PROHIBITED_CONTENT,
|
||||
types.FinishReason.SPII,
|
||||
types.FinishReason.BLOCKLIST,
|
||||
}:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
# 防止旧版本SDK不存在IMAGE_SAFETY
|
||||
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
|
||||
if finish_reason == types.FinishReason.IMAGE_SAFETY:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
if not result_parts:
|
||||
logger.debug(result.candidates)
|
||||
|
||||
148
astrbot/core/provider/sources/gsv_selfhosted_source.py
Normal file
148
astrbot/core/provider/sources/gsv_selfhosted_source.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
provider_type_name="gsv_tts_selfhost",
|
||||
desc="GPT-SoVITS TTS(本地加载)",
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
)
|
||||
class ProviderGSVTTS(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
|
||||
"/"
|
||||
)
|
||||
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
|
||||
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
|
||||
|
||||
# TTS 请求的默认参数,移除前缀gsv_
|
||||
self.default_params: dict = {
|
||||
key.removeprefix("gsv_"): str(value).lower()
|
||||
for key, value in provider_config.get("gsv_default_parms", {}).items()
|
||||
}
|
||||
self.timeout = provider_config.get("timeout", 60)
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化:在 ProviderManager 中被调用"""
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
)
|
||||
try:
|
||||
await self._set_model_weights()
|
||||
logger.info("[GSV TTS] 初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[GSV TTS] 初始化失败:{e}")
|
||||
raise
|
||||
|
||||
def get_session(self) -> aiohttp.ClientSession:
|
||||
if not self._session or self._session.closed:
|
||||
raise RuntimeError(
|
||||
"[GSV TTS] Provider HTTP session is not ready or closed."
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def _make_request(
|
||||
self, endpoint: str, params=None, retries: int = 3
|
||||
) -> bytes | None:
|
||||
"""发起请求"""
|
||||
for attempt in range(retries):
|
||||
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
|
||||
try:
|
||||
async with self.get_session().get(endpoint, params=params) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
|
||||
)
|
||||
return await response.read()
|
||||
except Exception as e:
|
||||
if attempt < retries - 1:
|
||||
logger.warning(
|
||||
f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
|
||||
raise
|
||||
|
||||
async def _set_model_weights(self):
|
||||
"""设置模型路径"""
|
||||
try:
|
||||
if self.gpt_weights_path:
|
||||
await self._make_request(
|
||||
f"{self.api_base}/set_gpt_weights",
|
||||
{"weights_path": self.gpt_weights_path},
|
||||
)
|
||||
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
|
||||
else:
|
||||
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
|
||||
|
||||
if self.sovits_weights_path:
|
||||
await self._make_request(
|
||||
f"{self.api_base}/set_sovits_weights",
|
||||
{"weights_path": self.sovits_weights_path},
|
||||
)
|
||||
logger.info(
|
||||
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
|
||||
)
|
||||
else:
|
||||
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
|
||||
if not text.strip():
|
||||
raise ValueError("[GSV TTS] TTS 文本不能为空")
|
||||
|
||||
endpoint = f"{self.api_base}/tts"
|
||||
|
||||
params = self.build_synthesis_params(text)
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
|
||||
|
||||
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
|
||||
|
||||
result = await self._make_request(endpoint, params)
|
||||
if isinstance(result, bytes):
|
||||
with open(path, "wb") as f:
|
||||
f.write(result)
|
||||
return path
|
||||
else:
|
||||
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
|
||||
|
||||
def build_synthesis_params(self, text: str) -> dict:
|
||||
"""
|
||||
构建语音合成所需的参数字典。
|
||||
|
||||
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
|
||||
"""
|
||||
params = self.default_params.copy()
|
||||
params["text"] = text
|
||||
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
|
||||
return params
|
||||
|
||||
async def terminate(self):
|
||||
"""终止释放资源:在 ProviderManager 中被调用"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
logger.info("[GSV TTS] Session 已关闭")
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import urllib.parse
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -23,7 +25,8 @@ class ProviderGSVITTS(TTSProvider):
|
||||
self.emotion = provider_config.get("emotion")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
params = {"text": text}
|
||||
|
||||
if self.character:
|
||||
|
||||
@@ -60,10 +60,12 @@ class LLMTunerModelLoader(Provider):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = [],
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
149
astrbot/core/provider/sources/minimax_tts_api_source.py
Normal file
149
astrbot/core/provider/sources/minimax_tts_api_source.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from typing import Dict, List, Union, AsyncIterator
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.api import logger
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"minimax_tts_api", "MiniMax TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.api_base: str = provider_config.get(
|
||||
"api_base", "https://api.minimax.chat/v1/t2a_v2"
|
||||
)
|
||||
self.group_id: str = provider_config.get("minimax-group-id", "")
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
|
||||
self.is_timber_weight: bool = provider_config.get(
|
||||
"minimax-is-timber-weight", False
|
||||
)
|
||||
self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads(
|
||||
provider_config.get(
|
||||
"minimax-timber-weight",
|
||||
'[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
|
||||
)
|
||||
)
|
||||
|
||||
self.voice_setting: dict = {
|
||||
"speed": provider_config.get("minimax-voice-speed", 1.0),
|
||||
"vol": provider_config.get("minimax-voice-vol", 1.0),
|
||||
"pitch": provider_config.get("minimax-voice-pitch", 0),
|
||||
"voice_id": ""
|
||||
if self.is_timber_weight
|
||||
else provider_config.get("minimax-voice-id", ""),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "neutral"),
|
||||
"latex_read": provider_config.get("minimax-voice-latex", False),
|
||||
"english_normalization": provider_config.get(
|
||||
"minimax-voice-english-normalization", False
|
||||
),
|
||||
}
|
||||
|
||||
self.audio_setting: dict = {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
}
|
||||
|
||||
self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
def _build_tts_stream_body(self, text: str):
|
||||
"""构建流式请求体"""
|
||||
dict_body: Dict[str, object] = {
|
||||
"model": self.model_name,
|
||||
"text": text,
|
||||
"stream": True,
|
||||
"language_boost": self.lang_boost,
|
||||
"voice_setting": self.voice_setting,
|
||||
"audio_setting": self.audio_setting,
|
||||
}
|
||||
if self.is_timber_weight:
|
||||
dict_body["timber_weights"] = self.timber_weight
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.concat_base_url,
|
||||
headers=self.headers,
|
||||
data=self._build_tts_stream_body(text),
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
buffer = b""
|
||||
while True:
|
||||
chunk = await response.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk
|
||||
|
||||
while b"\n\n" in buffer:
|
||||
try:
|
||||
message, buffer = buffer.split(b"\n\n", 1)
|
||||
if message.startswith(b"data: "):
|
||||
try:
|
||||
data = json.loads(message[6:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Failed to parse JSON data from SSE message"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
buffer = buffer[-1024:]
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
|
||||
|
||||
async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
|
||||
"""解码数据流到 audio 比特流"""
|
||||
chunks = []
|
||||
async for chunk in audio_stream:
|
||||
if chunk.strip():
|
||||
chunks.append(bytes.fromhex(chunk.strip()))
|
||||
return b"".join(chunks)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
|
||||
|
||||
try:
|
||||
# 直接将异步生成器传递给 _audio_play 方法
|
||||
audio_stream = self._call_tts_stream(text)
|
||||
audio = await self._audio_play(audio_stream)
|
||||
|
||||
# 结果保存至文件
|
||||
with open(path, "wb") as file:
|
||||
file.write(audio)
|
||||
|
||||
return path
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise e
|
||||
43
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
43
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from openai import AsyncOpenAI
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"openai_embedding",
|
||||
"OpenAI API Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
base_url=provider_config.get(
|
||||
"embedding_api_base", "https://api.openai.com/v1"
|
||||
),
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -195,7 +195,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
# workaround for #1454
|
||||
if isinstance(tool_call.function.arguments, str):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
else:
|
||||
args = tool_call.function.arguments
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
tool_call_ids.append(tool_call.id)
|
||||
@@ -223,9 +227,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: list=None,
|
||||
system_prompt: str=None,
|
||||
tool_calls_result: ToolCallsResult=None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -340,9 +344,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt,
|
||||
session_id = None,
|
||||
image_urls = None,
|
||||
func_tool = None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -31,7 +33,8 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name, voice=self.voice, response_format="wav", input=text
|
||||
) as response:
|
||||
|
||||
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
|
||||
@register_provider_adapter(
|
||||
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderVolcengineTTS(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.api_key = provider_config.get("api_key", "")
|
||||
self.appid = provider_config.get("appid", "")
|
||||
self.cluster = provider_config.get("volcengine_cluster", "")
|
||||
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
||||
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
||||
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
|
||||
def _build_request_payload(self, text: str) -> dict:
|
||||
return {
|
||||
"app": {
|
||||
"appid": self.appid,
|
||||
"token": self.api_key,
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4())
|
||||
},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
"encoding": "mp3",
|
||||
"speed_ratio": self.speed_ratio,
|
||||
"volume_ratio": 1.0,
|
||||
"pitch_ratio": 1.0,
|
||||
},
|
||||
"request": {
|
||||
"reqid": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"text_type": "plain",
|
||||
"operation": "query",
|
||||
"with_frontend": 1,
|
||||
"frontend_type": "unitTson"
|
||||
}
|
||||
}
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""异步方法获取语音文件路径"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer; {self.api_key}"
|
||||
}
|
||||
|
||||
payload = self._build_request_payload(text)
|
||||
|
||||
logger.debug(f"请求头: {headers}")
|
||||
logger.debug(f"请求 URL: {self.api_base}")
|
||||
logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.api_base,
|
||||
data=json.dumps(payload),
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
logger.debug(f"响应状态码: {response.status}")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(f"响应内容: {response_text[:200]}...")
|
||||
|
||||
if response.status == 200:
|
||||
resp_data = json.loads(response_text)
|
||||
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data)
|
||||
)
|
||||
|
||||
return file_path
|
||||
else:
|
||||
error_msg = resp_data.get("message", "未知错误")
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
||||
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -50,7 +51,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -61,7 +63,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -53,7 +54,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -64,7 +66,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -31,10 +31,12 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class SimpleOpenAIEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
@@ -1,94 +0,0 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
|
||||
class KnowledgeDBManager:
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(
|
||||
name, cfg["embedding_config"]
|
||||
)
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [
|
||||
f
|
||||
for f in os.listdir(self.db_path)
|
||||
if os.path.isfile(os.path.join(self.db_path, f))
|
||||
]
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
"""
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]["strategy"]:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
class Store:
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
@@ -1,42 +0,0 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path="data/long_term_memory_chroma.db"
|
||||
)
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None),
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding,
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, query: str, top_n=3, metadata_filter: Dict = None
|
||||
) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding, n_results=top_n, where=metadata_filter
|
||||
)
|
||||
return results["documents"][0]
|
||||
@@ -5,6 +5,7 @@
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def load_config(namespace: str) -> Union[dict, bool]:
|
||||
@@ -13,7 +14,7 @@ def load_config(namespace: str) -> Union[dict, bool]:
|
||||
namespace: str, 配置的唯一识别符,也就是配置文件的名字。
|
||||
返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
@@ -43,7 +44,10 @@ def put_config(namespace: str, name: str, key: str, value, description: str):
|
||||
raise ValueError("key 只支持 str 类型。")
|
||||
if not isinstance(value, (str, int, float, bool, list)):
|
||||
raise ValueError("value 只支持 str, int, float, bool, list 类型。")
|
||||
path = f"data/config/{namespace}.json"
|
||||
|
||||
config_dir = os.path.join(get_astrbot_data_path(), "config")
|
||||
path = os.path.join(config_dir, f"{namespace}.json")
|
||||
|
||||
if not os.path.exists(path):
|
||||
with open(path, "w", encoding="utf-8-sig") as f:
|
||||
f.write("{}")
|
||||
@@ -71,7 +75,7 @@ def update_config(namespace: str, key: str, value):
|
||||
key: str, 配置项的键。
|
||||
value: str, int, float, bool, list, 配置项的值。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。")
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -16,7 +17,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
@@ -42,6 +42,8 @@ class Context:
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
@@ -54,14 +56,12 @@ class Context:
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
@@ -126,11 +126,8 @@ class Context:
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
@@ -144,24 +141,46 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_using_tts_provider(self) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
|
||||
def get_using_stt_provider(self) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
@@ -301,3 +320,12 @@ class Context:
|
||||
注册一个异步任务。
|
||||
"""
|
||||
self._register_tasks.append(task)
|
||||
|
||||
def register_web_api(
|
||||
self, route: str, view_handler: Awaitable, methods: list, desc: str
|
||||
):
|
||||
for idx, api in enumerate(self.registered_web_apis):
|
||||
if api[0] == route and methods == api[2]:
|
||||
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
|
||||
return
|
||||
self.registered_web_apis.append((route, view_handler, methods, desc))
|
||||
|
||||
@@ -7,6 +7,9 @@ from astrbot.core.config import AstrBotConfig
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
class GreedyStr(str):
|
||||
"""标记指令完成其他参数接收后的所有剩余文本。"""
|
||||
pass
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter):
|
||||
@@ -68,7 +71,22 @@ class CommandFilter(HandlerFilter):
|
||||
) -> Dict[str, Any]:
|
||||
"""将参数列表 params 根据 param_type 转换为参数字典。"""
|
||||
result = {}
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
param_items = list(param_type.items())
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_items):
|
||||
is_greedy = param_type_or_default_val is GreedyStr
|
||||
|
||||
if is_greedy:
|
||||
# GreedyStr 必须是最后一个参数
|
||||
if i != len(param_items) - 1:
|
||||
raise ValueError(
|
||||
f"参数 '{param_name}' (GreedyStr) 必须是最后一个参数。"
|
||||
)
|
||||
|
||||
# 将剩余的所有部分合并成一个字符串
|
||||
remaining_params = params[i:]
|
||||
result[param_name] = " ".join(remaining_params)
|
||||
break
|
||||
# 没有 GreedyStr 的情况
|
||||
if i >= len(params):
|
||||
if (
|
||||
isinstance(param_type_or_default_val, Type)
|
||||
|
||||
@@ -113,7 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
)
|
||||
raise ValueError(
|
||||
f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||
+ tree
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import heapq
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
@@ -8,100 +7,66 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
"""用于存储所有的 Star Handler"""
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
"""用于快速查找。key 是 handler_full_name"""
|
||||
_handlers = []
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
self._handlers: List[StarHandlerMetadata] = []
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
"""添加一个 Handler"""
|
||||
"""添加一个 Handler,并保持按优先级有序"""
|
||||
if "priority" not in handler.extras_configs:
|
||||
handler.extras_configs["priority"] = 0
|
||||
|
||||
heapq.heappush(self._handlers, (-handler.extras_configs["priority"], handler))
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
self._handlers.append(handler)
|
||||
self._handlers.sort(key=lambda h: -h.extras_configs["priority"])
|
||||
|
||||
def _print_handlers(self):
|
||||
"""打印所有的 Handler"""
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过事件类型获取 Handler
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
only_activated: 是否只返回已激活的插件的处理器
|
||||
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||
|
||||
Returns:
|
||||
List[StarHandlerMetadata]: 处理器列表
|
||||
"""
|
||||
handlers = []
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
|
||||
# 只激活的插件处理器
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
|
||||
# 平台兼容性过滤
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
continue
|
||||
|
||||
handlers.append(handler)
|
||||
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
"""通过 Handler 的全名获取 Handler"""
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过模块名获取 Handler"""
|
||||
return [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
handler for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
def clear(self):
|
||||
"""清空所有的 Handler"""
|
||||
self.star_handlers_map.clear()
|
||||
self._handlers.clear()
|
||||
|
||||
def remove(self, handler: StarHandlerMetadata):
|
||||
"""删除一个 Handler"""
|
||||
# self._handlers.remove(handler)
|
||||
for i, h in enumerate(self._handlers):
|
||||
if h[1] == handler:
|
||||
self._handlers.pop(i)
|
||||
break
|
||||
try:
|
||||
del self.star_handlers_map[handler.handler_full_name]
|
||||
except KeyError:
|
||||
pass
|
||||
self.star_handlers_map.pop(handler.handler_full_name, None)
|
||||
self._handlers = [h for h in self._handlers if h != handler]
|
||||
|
||||
def __iter__(self):
|
||||
"""使 StarHandlerRegistry 支持迭代"""
|
||||
return (handler for _, handler in self._handlers)
|
||||
return iter(self._handlers)
|
||||
|
||||
def __len__(self):
|
||||
"""返回 Handler 的数量"""
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
|
||||
@@ -2,28 +2,46 @@
|
||||
插件的重载、启停、安装、卸载等操作。
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
import asyncio
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
import yaml
|
||||
|
||||
from astrbot.core import logger, pip_installer, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
|
||||
from . import StarMetadata
|
||||
from .context import Context
|
||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||
from .star import star_map, star_registry
|
||||
from .star_handler import star_handlers_registry
|
||||
from .updator import PluginUpdator
|
||||
|
||||
try:
|
||||
from watchfiles import PythonFilter, awatch
|
||||
except ImportError:
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
@@ -34,17 +52,9 @@ class PluginManager:
|
||||
self.context._star_manager = self
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/config"
|
||||
)
|
||||
)
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
@@ -56,6 +66,58 @@ class PluginManager:
|
||||
"""插件配置 Schema 文件名"""
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
asyncio.create_task(self._watch_plugins_changes())
|
||||
|
||||
async def _watch_plugins_changes(self):
|
||||
"""监视插件文件变化"""
|
||||
try:
|
||||
async for changes in awatch(
|
||||
self.plugin_store_path,
|
||||
self.reserved_plugin_path,
|
||||
watch_filter=PythonFilter(),
|
||||
recursive=True,
|
||||
):
|
||||
# 处理文件变化
|
||||
await self._handle_file_changes(changes)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"插件热重载监视任务异常: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _handle_file_changes(self, changes):
|
||||
"""处理文件变化"""
|
||||
logger.info(f"检测到文件变化: {changes}")
|
||||
plugins_to_check = []
|
||||
|
||||
for star in star_registry:
|
||||
if not star.activated:
|
||||
continue
|
||||
if star.root_dir_name is None:
|
||||
continue
|
||||
if star.reserved:
|
||||
plugin_dir_path = os.path.join(
|
||||
self.reserved_plugin_path, star.root_dir_name
|
||||
)
|
||||
else:
|
||||
plugin_dir_path = os.path.join(
|
||||
self.plugin_store_path, star.root_dir_name
|
||||
)
|
||||
plugins_to_check.append((plugin_dir_path, star.name))
|
||||
reloaded_plugins = set()
|
||||
for change in changes:
|
||||
_, file_path = change
|
||||
for plugin_dir_path, plugin_name in plugins_to_check:
|
||||
if (
|
||||
os.path.commonpath([plugin_dir_path])
|
||||
== os.path.commonpath([plugin_dir_path, file_path])
|
||||
and plugin_name not in reloaded_plugins
|
||||
):
|
||||
logger.info(f"检测到插件 {plugin_name} 文件变化,正在重载...")
|
||||
await self.reload(plugin_name)
|
||||
reloaded_plugins.add(plugin_name)
|
||||
break
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
||||
@@ -104,7 +166,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -123,7 +185,7 @@ class PluginManager:
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
pip_installer.install(requirements_path=pth)
|
||||
await pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
@@ -345,7 +407,7 @@ class PluginManager:
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# 尝试安装依赖
|
||||
self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
await self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -389,11 +451,11 @@ class PluginManager:
|
||||
metadata.repo = metadata_yaml.repo
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
metadata.config = plugin_config
|
||||
if path not in inactivated_plugins:
|
||||
# 只有没有禁用插件时才实例化插件类
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
# metadata.config = plugin_config
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context, config=plugin_config
|
||||
@@ -580,16 +642,21 @@ class PluginManager:
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path):
|
||||
if os.path.exists(readme_path) and nh3:
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
plugin_info = {
|
||||
"repo": plugin.repo,
|
||||
"readme": cleaned_content,
|
||||
"name": plugin.name,
|
||||
}
|
||||
|
||||
return plugin_info
|
||||
|
||||
@@ -784,6 +851,10 @@ class PluginManager:
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
plugin_info = {
|
||||
"repo": plugin.repo,
|
||||
"readme": readme_content,
|
||||
"name": plugin.name,
|
||||
}
|
||||
|
||||
return plugin_info
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -6,7 +8,7 @@ from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map
|
||||
from pathlib import Path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class StarTools:
|
||||
@@ -180,7 +182,7 @@ class StarTools:
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path("data/plugin_data") / plugin_name
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user