Compare commits
592 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da546cfe7f | ||
|
|
a211933e83 | ||
|
|
1d40b5a821 | ||
|
|
33836daeb7 | ||
|
|
0de6d0e046 | ||
|
|
9fedaa9f77 | ||
|
|
bf4c2ecd33 | ||
|
|
f8c18cc1e0 | ||
|
|
458b900412 | ||
|
|
192c776e0b | ||
|
|
5cdec18863 | ||
|
|
15f856f951 | ||
|
|
01d52cef74 | ||
|
|
31d8c40eca | ||
|
|
56001ed272 | ||
|
|
cfae655068 | ||
|
|
5596565ec4 | ||
|
|
e98c3d8393 | ||
|
|
6687b816f0 | ||
|
|
ea8035e854 | ||
|
|
54b0171d49 | ||
|
|
676d4277b9 | ||
|
|
a4b1da3ca2 | ||
|
|
9e9c16e770 | ||
|
|
dc87006fed | ||
|
|
b9b260f26a | ||
|
|
33fd6a5016 | ||
|
|
97cbccc2ba | ||
|
|
1ee4685d5d | ||
|
|
aba18232b1 | ||
|
|
0a02441b75 | ||
|
|
1be5b4c7ff | ||
|
|
a0ce0cf18a | ||
|
|
7c54e5d093 | ||
|
|
b825e51dab | ||
|
|
589855c393 | ||
|
|
4c546f2f53 | ||
|
|
3753fce912 | ||
|
|
4c02857ec5 | ||
|
|
33f87ff7d7 | ||
|
|
784dcf2a9a | ||
|
|
43ee943acb | ||
|
|
a769fd7d13 | ||
|
|
2c4fd00b16 | ||
|
|
264771fe98 | ||
|
|
ecd92dafef | ||
|
|
c8b6e4bea3 | ||
|
|
3756cb766e | ||
|
|
068d9ca60b | ||
|
|
93f632d8b8 | ||
|
|
bb44ce7e74 | ||
|
|
6986c8d8f7 | ||
|
|
fe95506db4 | ||
|
|
310ed76b18 | ||
|
|
98830d147f | ||
|
|
19c9177d7b | ||
|
|
f41c5f97f6 | ||
|
|
648c125697 | ||
|
|
0dc2b89897 | ||
|
|
83745f83a5 | ||
|
|
2f91fe4535 | ||
|
|
739f09059e | ||
|
|
c86f9f0f5f | ||
|
|
9470ca6bc5 | ||
|
|
2a92c4d5de | ||
|
|
bb6e892657 | ||
|
|
c9079b9299 | ||
|
|
b6963c1bf9 | ||
|
|
9c29df47bb | ||
|
|
fc146d3d00 | ||
|
|
1bf5a21678 | ||
|
|
011542dc2b | ||
|
|
489784104e | ||
|
|
3860634fd2 | ||
|
|
709c324e18 | ||
|
|
b75d24d92c | ||
|
|
ed80e9424c | ||
|
|
2fe1f2060a | ||
|
|
c6df820164 | ||
|
|
d6239822db | ||
|
|
bced9ffff9 | ||
|
|
d7d1c1544a | ||
|
|
e3b0ca8ef6 | ||
|
|
9e266eb6d5 | ||
|
|
7231403e16 | ||
|
|
344a486fd7 | ||
|
|
4fd831875d | ||
|
|
0988d067ea | ||
|
|
44dbe475af | ||
|
|
bd24cf3ea4 | ||
|
|
b493a808fe | ||
|
|
54035d108d | ||
|
|
c5e8bc7e20 | ||
|
|
3bbb4779a3 | ||
|
|
1b3963ebea | ||
|
|
3b6dd7e15a | ||
|
|
757d2a3947 | ||
|
|
61b71143f2 | ||
|
|
1b343a36c9 | ||
|
|
8e94937060 | ||
|
|
e8ffebc006 | ||
|
|
2ca95eaa9f | ||
|
|
0dc5b4cdfc | ||
|
|
cc6cd96d8e | ||
|
|
4244d37625 | ||
|
|
0b766095d4 | ||
|
|
a4f212a18f | ||
|
|
caafb73190 | ||
|
|
09482799c9 | ||
|
|
37f93d1760 | ||
|
|
725f2e5204 | ||
|
|
967198fae0 | ||
|
|
43d57f6dcb | ||
|
|
6afa4db577 | ||
|
|
3b8c3fb29a | ||
|
|
921c3b0627 | ||
|
|
c0fadb45ab | ||
|
|
a1481fb179 | ||
|
|
987cd972d3 | ||
|
|
bdf25976a3 | ||
|
|
87c3aff4ce | ||
|
|
99350a957a | ||
|
|
319068dc7e | ||
|
|
cd18806c39 | ||
|
|
95b08b2023 | ||
|
|
0e70f76c86 | ||
|
|
4d414a2994 | ||
|
|
3d22772d4e | ||
|
|
0b381e2570 | ||
|
|
f2cc4311c5 | ||
|
|
e349671fdf | ||
|
|
01c02d5efa | ||
|
|
b62b1f3870 | ||
|
|
8844830859 | ||
|
|
0c51ee4b64 | ||
|
|
11920d5e31 | ||
|
|
848ea1eb63 | ||
|
|
a216519486 | ||
|
|
b04606c38e | ||
|
|
38072beea7 | ||
|
|
b843f1fa03 | ||
|
|
560d40e571 | ||
|
|
5f0b8161b7 | ||
|
|
062d482917 | ||
|
|
39693a27e3 | ||
|
|
7cd1eeac30 | ||
|
|
bafa473c8e | ||
|
|
750cf46b2e | ||
|
|
68885a4bbc | ||
|
|
bcc99a8904 | ||
|
|
59fbd98db3 | ||
|
|
b70ed425f1 | ||
|
|
45ef5811c8 | ||
|
|
3b137ac762 | ||
|
|
1ddb0caf73 | ||
|
|
ae4c6fe2dd | ||
|
|
b03fe438d0 | ||
|
|
db257af58e | ||
|
|
735368c71b | ||
|
|
9e04e3679b | ||
|
|
43b8414727 | ||
|
|
5a00187147 | ||
|
|
cb525c7c84 | ||
|
|
d88420dd03 | ||
|
|
b9a983f8e0 | ||
|
|
42431ea7db | ||
|
|
f9459e4abb | ||
|
|
72f917d611 | ||
|
|
9fd1d19e93 | ||
|
|
062af1ac08 | ||
|
|
41bd76e091 | ||
|
|
cfd3f4b199 | ||
|
|
79d38f9597 | ||
|
|
b3866559e1 | ||
|
|
4d186baa35 | ||
|
|
8ed3d5f3db | ||
|
|
f0c8f39b6d | ||
|
|
431db8fc9b | ||
|
|
ba252c5356 | ||
|
|
a2812c39c0 | ||
|
|
0490758820 | ||
|
|
7f56824b42 | ||
|
|
627da3a2bc | ||
|
|
9b36a5c8a6 | ||
|
|
c1cf2be533 | ||
|
|
e6b69042de | ||
|
|
109650faf3 | ||
|
|
e54eaab842 | ||
|
|
43b6297b5d | ||
|
|
c20f4f5adf | ||
|
|
dc1f222cd2 | ||
|
|
c2b687212c | ||
|
|
849913276d | ||
|
|
23579c1e4a | ||
|
|
e031161fd4 | ||
|
|
4800ee6c0a | ||
|
|
d3a7fef9b0 | ||
|
|
40822fe77a | ||
|
|
837b670213 | ||
|
|
57ce69f3fb | ||
|
|
be022c4894 | ||
|
|
8a366964bb | ||
|
|
ee86b68470 | ||
|
|
60352307aa | ||
|
|
3ebd2f746f | ||
|
|
1c1a65b637 | ||
|
|
010e60d029 | ||
|
|
7a25568861 | ||
|
|
5f4f913661 | ||
|
|
ccd0e34a53 | ||
|
|
72f1ffccd3 | ||
|
|
ea7a52945f | ||
|
|
89d4d1351a | ||
|
|
b757c91d93 | ||
|
|
27203d7a4d | ||
|
|
9ad4e18ac5 | ||
|
|
fcdc8f3ce7 | ||
|
|
78b994b84a | ||
|
|
58bfc677e2 | ||
|
|
7d17285a0c | ||
|
|
e9eb00a0d4 | ||
|
|
48d07af574 | ||
|
|
2fc62efd88 | ||
|
|
be516d75bd | ||
|
|
951d5fde85 | ||
|
|
1389abc052 | ||
|
|
19ad67a77f | ||
|
|
641f308344 | ||
|
|
9f097fa4d5 | ||
|
|
5ad362c52b | ||
|
|
614f238a61 | ||
|
|
dec91950bc | ||
|
|
6cef9c23f0 | ||
|
|
3f568bf136 | ||
|
|
5484b421ce | ||
|
|
02f21e07d3 | ||
|
|
fff1f23a83 | ||
|
|
a056ec0d38 | ||
|
|
2eb9e5dde3 | ||
|
|
627d2a4701 | ||
|
|
76895fe86d | ||
|
|
64c3c85780 | ||
|
|
7288348857 | ||
|
|
62e73299b1 | ||
|
|
fe76c41ed8 | ||
|
|
1a92edf8be | ||
|
|
b63b606a4e | ||
|
|
8e2ef3d22b | ||
|
|
c6c4a32283 | ||
|
|
b70b3b158e | ||
|
|
3d59ab8108 | ||
|
|
b6c3089510 | ||
|
|
bd92aac280 | ||
|
|
5299e802e9 | ||
|
|
8e5a57d7dd | ||
|
|
beaa324fb6 | ||
|
|
79e64fe206 | ||
|
|
93f525e3fe | ||
|
|
aacb803c64 | ||
|
|
8a0665b222 | ||
|
|
20e41a7f73 | ||
|
|
93a1699a35 | ||
|
|
c33c07e4af | ||
|
|
c7484d0cc9 | ||
|
|
fb85a7bb35 | ||
|
|
42ff9a4d34 | ||
|
|
005e9eae7c | ||
|
|
3e325debcc | ||
|
|
a221de9a2b | ||
|
|
32b0cc1865 | ||
|
|
bbf85f8a12 | ||
|
|
67a0172b28 | ||
|
|
fb19d4d45b | ||
|
|
a156b1af14 | ||
|
|
a604b4943c | ||
|
|
3f0b6435d9 | ||
|
|
e0f029e2cb | ||
|
|
89d3fd5fab | ||
|
|
a38b00be6b | ||
|
|
0e8d52b591 | ||
|
|
298c77740d | ||
|
|
c681aae8ee | ||
|
|
faef98b089 | ||
|
|
84a3e0a30b | ||
|
|
69bd553ce0 | ||
|
|
fd0c0f8975 | ||
|
|
860ceb06b4 | ||
|
|
ecf501bf72 | ||
|
|
81a2ed1e25 | ||
|
|
76ab28338a | ||
|
|
9a56c9630f | ||
|
|
53b9497c18 | ||
|
|
750b16b6ee | ||
|
|
0ee3e0779a | ||
|
|
333c2d9299 | ||
|
|
ad37ff5048 | ||
|
|
33f86f3bde | ||
|
|
8acb969a49 | ||
|
|
b74b5933b8 | ||
|
|
681c556b7e | ||
|
|
1746684e52 | ||
|
|
0b93d06555 | ||
|
|
8a8b8c7c27 | ||
|
|
6b6577006d | ||
|
|
23ee5e81c9 | ||
|
|
483f55e4b1 | ||
|
|
1bb1bc2553 | ||
|
|
a4e4e36f94 | ||
|
|
6849415812 | ||
|
|
86f6cb038e | ||
|
|
7480a1d6ce | ||
|
|
3cd10117dd | ||
|
|
0caf19d390 | ||
|
|
5c14ebb049 | ||
|
|
9717a736b1 | ||
|
|
9c9ab50d1a | ||
|
|
d4bcb8174e | ||
|
|
9e7fe773bd | ||
|
|
aca18fab0f | ||
|
|
691de01b79 | ||
|
|
3383f15142 | ||
|
|
84c1593889 | ||
|
|
3c80fa1e33 | ||
|
|
06b16a1deb | ||
|
|
4c4246fb09 | ||
|
|
364be1e9f6 | ||
|
|
f959ed71aa | ||
|
|
5c4326c302 | ||
|
|
125fc3a622 | ||
|
|
6b9e785db3 | ||
|
|
25d34e9a43 | ||
|
|
457d4aa1dc | ||
|
|
ff0c0992ff | ||
|
|
d379e012c4 | ||
|
|
151fff26fd | ||
|
|
3d0d561215 | ||
|
|
22d586ed7b | ||
|
|
6dc19b29e8 | ||
|
|
50975a87d4 | ||
|
|
ce721d9f0f | ||
|
|
20510a33f7 | ||
|
|
3abd9c8763 | ||
|
|
e9eff7420b | ||
|
|
64c250c9d8 | ||
|
|
8047f82bfd | ||
|
|
af6467fb3d | ||
|
|
3ff1664aec | ||
|
|
34ea2b44b8 | ||
|
|
6c8d851109 | ||
|
|
d678299a74 | ||
|
|
7aed0db2b6 | ||
|
|
0355524345 | ||
|
|
0a43e4672e | ||
|
|
71e0ccdfec | ||
|
|
1df33ac3c8 | ||
|
|
7334090ac1 | ||
|
|
6b0f044198 | ||
|
|
ddf54c9cf8 | ||
|
|
7c64e184e2 | ||
|
|
a904db033c | ||
|
|
b234856b02 | ||
|
|
89d51d2afc | ||
|
|
37cb9678e9 | ||
|
|
0500ff333a | ||
|
|
08528510ef | ||
|
|
ddbd03dc1e | ||
|
|
ade87f378a | ||
|
|
4db14b905f | ||
|
|
b669b31451 | ||
|
|
1cb2b62f81 | ||
|
|
e5828713cf | ||
|
|
d10cb84068 | ||
|
|
4222f8516f | ||
|
|
7f998c7611 | ||
|
|
db46000337 | ||
|
|
1aac8d8041 | ||
|
|
c59c8e05f7 | ||
|
|
4942d0a629 | ||
|
|
873b7715f4 | ||
|
|
98e7ed6920 | ||
|
|
046f5e645e | ||
|
|
f5e5a7094c | ||
|
|
154125fee6 | ||
|
|
9f8e960ebe | ||
|
|
4179b0be0a | ||
|
|
28bafa38db | ||
|
|
b07552565e | ||
|
|
c4427471d2 | ||
|
|
08f81c6784 | ||
|
|
a471e98aca | ||
|
|
75a8fcc8a0 | ||
|
|
46ef76c168 | ||
|
|
66637446c9 | ||
|
|
21efeb888a | ||
|
|
a4ee8b5322 | ||
|
|
36519ac47e | ||
|
|
3f514fceca | ||
|
|
c2249fdfac | ||
|
|
c610719a44 | ||
|
|
36a6c2461a | ||
|
|
c29f22c39e | ||
|
|
30d3062944 | ||
|
|
69ba75abf4 | ||
|
|
e4d486fec5 | ||
|
|
f242144dcf | ||
|
|
02dee2d664 | ||
|
|
a3dd2c3069 | ||
|
|
a23425e8aa | ||
|
|
be79ddc9a3 | ||
|
|
7d71015e8c | ||
|
|
ad54549b51 | ||
|
|
6cf032a164 | ||
|
|
6390d796ac | ||
|
|
98b8411905 | ||
|
|
ddf1029afa | ||
|
|
1effbc5cc9 | ||
|
|
414b645e9f | ||
|
|
398c76f496 | ||
|
|
1bc456dd95 | ||
|
|
2e8421884e | ||
|
|
70d9b193ac | ||
|
|
b49c11004a | ||
|
|
34843eea90 | ||
|
|
2d6d7f31e8 | ||
|
|
7a24cbff1c | ||
|
|
1e7eb2cf1c | ||
|
|
361256e016 | ||
|
|
8838dbd003 | ||
|
|
13a95e1f2b | ||
|
|
1aaa451a3e | ||
|
|
cbba81e54d | ||
|
|
370868dfac | ||
|
|
77f692aae2 | ||
|
|
9318e205ea | ||
|
|
ebcc717c19 | ||
|
|
4c16b564ee | ||
|
|
e2283d1453 | ||
|
|
d891801c5a | ||
|
|
de75386944 | ||
|
|
82dc37de50 | ||
|
|
b6fa7f62dc | ||
|
|
f9e0a95c5e | ||
|
|
b2c6e12647 | ||
|
|
caffb83780 | ||
|
|
8882cb5479 | ||
|
|
75dace2dee | ||
|
|
ad6487d042 | ||
|
|
a91604e8ab | ||
|
|
c364f7c643 | ||
|
|
53435ba184 | ||
|
|
25f8d5519b | ||
|
|
2e4fef6c66 | ||
|
|
80b2b7dc00 | ||
|
|
8585cd8e21 | ||
|
|
9fa2a7eeea | ||
|
|
2d1f74228d | ||
|
|
3d6f7aa0e1 | ||
|
|
3dea60366a | ||
|
|
d4d9a1df4c | ||
|
|
7d6975fd31 | ||
|
|
08be52ed17 | ||
|
|
682a7700c2 | ||
|
|
9d87009216 | ||
|
|
ef86838f62 | ||
|
|
35468233f8 | ||
|
|
26e229867d | ||
|
|
3a1578b3c6 | ||
|
|
d5e3d2cbbc | ||
|
|
c095248176 | ||
|
|
44601c8954 | ||
|
|
135dbb8f07 | ||
|
|
c95682a0c7 | ||
|
|
d177b9f7fa | ||
|
|
9b57615d94 | ||
|
|
c03f3eacd1 | ||
|
|
a26e395932 | ||
|
|
0870b87c96 | ||
|
|
b52a44a7dd | ||
|
|
0a290aafef | ||
|
|
9014d4c410 | ||
|
|
60e58b4f5f | ||
|
|
620e74a6aa | ||
|
|
efa287ed35 | ||
|
|
a24eb9d9b0 | ||
|
|
bd3dab8aae | ||
|
|
4fe1ebaa5b | ||
|
|
c5e944744b | ||
|
|
0c396181f7 | ||
|
|
0034474219 | ||
|
|
8136ad8287 | ||
|
|
681940d466 | ||
|
|
16488506e8 | ||
|
|
122fccc041 | ||
|
|
9d0ad35403 | ||
|
|
f9ec97e026 | ||
|
|
95495a2647 | ||
|
|
e3310a605c | ||
|
|
b55719bf28 | ||
|
|
b957b51279 | ||
|
|
90bcfab369 | ||
|
|
f8a8e30641 | ||
|
|
25cb98e7a7 | ||
|
|
03e1bb7cf9 | ||
|
|
85dbb24f3a | ||
|
|
d817635782 | ||
|
|
2f4f237810 | ||
|
|
5ac94d810f | ||
|
|
39dc46dc25 | ||
|
|
0d9cf725f7 | ||
|
|
e55dbead5b | ||
|
|
7d046e5b30 | ||
|
|
8b4693cf66 | ||
|
|
a1172c9a82 | ||
|
|
1ed2bd33f0 | ||
|
|
4c159bd0ba | ||
|
|
050654b2a9 | ||
|
|
61b261e1b2 | ||
|
|
017b010206 | ||
|
|
00f5189f58 | ||
|
|
4a8309ed1f | ||
|
|
76cfc31a1d | ||
|
|
d9ec434699 | ||
|
|
239f3c40be | ||
|
|
09c8c6e670 | ||
|
|
7e4ad01c94 | ||
|
|
ed98e269ef | ||
|
|
b47d63334f | ||
|
|
5e2a3a5aea | ||
|
|
1a7eb21fc7 | ||
|
|
834a51cdc9 | ||
|
|
1b69d99c06 | ||
|
|
ad189933c6 | ||
|
|
9d86ff32de | ||
|
|
278bb57a58 | ||
|
|
0ba494e0ba | ||
|
|
8b247054bb | ||
|
|
7c5c8e4e0d | ||
|
|
ad106a27f3 | ||
|
|
9d6f61b49e | ||
|
|
02368954a0 | ||
|
|
b477a35a01 | ||
|
|
16622887de | ||
|
|
9059d1fb17 | ||
|
|
df2b008d82 | ||
|
|
0da871efd0 | ||
|
|
1c55349f81 | ||
|
|
9309fa1e81 | ||
|
|
5996189f91 | ||
|
|
bd2b984bfb | ||
|
|
194409a117 | ||
|
|
27978b216d | ||
|
|
c38fa77ce6 | ||
|
|
3eb49f7422 | ||
|
|
1989d615d2 | ||
|
|
239412d265 | ||
|
|
375a419a9e | ||
|
|
875c8ab424 | ||
|
|
c9bfc810ce | ||
|
|
46ecb16949 | ||
|
|
f6dc16f17b | ||
|
|
4eef42f730 | ||
|
|
8612d9a771 | ||
|
|
0caff054f5 | ||
|
|
4aa91ad599 | ||
|
|
7a0864f5c2 | ||
|
|
73dc0dfcf6 | ||
|
|
1ff9a69339 | ||
|
|
179eb5d847 | ||
|
|
52c868828c | ||
|
|
7eea4615b6 | ||
|
|
d9b351df1a | ||
|
|
d6a785b645 | ||
|
|
79db828a01 | ||
|
|
a5ffb0f8dc | ||
|
|
9492fcde74 | ||
|
|
d2456ce4cd | ||
|
|
7de27abc8d | ||
|
|
d8155bc8eb | ||
|
|
cf08e52a92 | ||
|
|
768398b991 | ||
|
|
24c20a19f1 | ||
|
|
8fbcbcd4c0 | ||
|
|
e0da5bb943 | ||
|
|
36fbc4fb82 | ||
|
|
cb11051f42 | ||
|
|
a824781d14 | ||
|
|
600a2c6748 | ||
|
|
77df64bfb5 | ||
|
|
2d6e54903c | ||
|
|
baa2b83df9 | ||
|
|
1ff02446af | ||
|
|
b58c6ba762 |
@@ -17,4 +17,7 @@ ENV/
|
|||||||
.conda/
|
.conda/
|
||||||
README*.md
|
README*.md
|
||||||
dashboard/
|
dashboard/
|
||||||
data/
|
data/
|
||||||
|
changelogs/
|
||||||
|
tests/
|
||||||
|
.ruff_cache/
|
||||||
9
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
9
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -6,7 +6,7 @@ body:
|
|||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: |
|
value: |
|
||||||
欢迎发布插件到插件市场!
|
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
@@ -22,9 +22,10 @@ body:
|
|||||||
插件名:
|
插件名:
|
||||||
插件作者:
|
插件作者:
|
||||||
插件简介:
|
插件简介:
|
||||||
标签: (可选)
|
支持的消息平台:(必填,如 QQ、微信、飞书)
|
||||||
社交链接: (可选, 将会在插件市场作者名称上作为可点击的链接)
|
标签:(可选)
|
||||||
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。
|
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
|
||||||
|
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
|
||||||
|
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -8,3 +8,7 @@
|
|||||||
### Modifications
|
### Modifications
|
||||||
|
|
||||||
<!--简单解释你的改动-->
|
<!--简单解释你的改动-->
|
||||||
|
|
||||||
|
### Check
|
||||||
|
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||||
|
- [ ] 我新增/修复/优化的功能经过良好的测试
|
||||||
|
|||||||
31
.github/workflows/dashboard_ci.yml
vendored
Normal file
31
.github/workflows/dashboard_ci.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
name: AstrBot Dashboard CI
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: npm install, build
|
||||||
|
run: |
|
||||||
|
cd dashboard
|
||||||
|
npm install
|
||||||
|
npm run build
|
||||||
|
|
||||||
|
- name: Inject Commit SHA
|
||||||
|
id: get_sha
|
||||||
|
run: |
|
||||||
|
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||||
|
mkdir -p dashboard/dist/assets
|
||||||
|
echo $COMMIT_SHA > dashboard/dist/assets/version
|
||||||
|
|
||||||
|
- name: Archive production artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: dist-without-markdown
|
||||||
|
path: |
|
||||||
|
dashboard/dist
|
||||||
|
!dist/**/*.md
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,6 +1,8 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
botpy.log
|
botpy.log
|
||||||
.vscode
|
.vscode
|
||||||
|
.venv*
|
||||||
|
.idea
|
||||||
data_v2.db
|
data_v2.db
|
||||||
data_v3.db
|
data_v3.db
|
||||||
configs/session
|
configs/session
|
||||||
@@ -26,3 +28,5 @@ venv/*
|
|||||||
packages/python_interpreter/workplace
|
packages/python_interpreter/workplace
|
||||||
.venv/*
|
.venv/*
|
||||||
.conda/
|
.conda/
|
||||||
|
.idea
|
||||||
|
pytest.ini
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ ci:
|
|||||||
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.9.9
|
rev: v0.11.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
17
Dockerfile
17
Dockerfile
@@ -4,19 +4,32 @@ WORKDIR /AstrBot
|
|||||||
COPY . /AstrBot/
|
COPY . /AstrBot/
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
nodejs \
|
||||||
|
npm \
|
||||||
gcc \
|
gcc \
|
||||||
build-essential \
|
build-essential \
|
||||||
python3-dev \
|
python3-dev \
|
||||||
libffi-dev \
|
libffi-dev \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
|
ca-certificates \
|
||||||
|
bash \
|
||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN python -m pip install -r requirements.txt --no-cache-dir
|
RUN python -m pip install uv
|
||||||
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
|
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||||
|
|
||||||
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
|
# 释出 ffmpeg
|
||||||
|
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||||
|
|
||||||
|
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||||
|
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||||
|
|
||||||
EXPOSE 6185
|
EXPOSE 6185
|
||||||
EXPOSE 6186
|
EXPOSE 6186
|
||||||
|
|
||||||
CMD [ "python", "main.py" ]
|
CMD [ "python", "main.py" ]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
35
Dockerfile_with_node
Normal file
35
Dockerfile_with_node
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
WORKDIR /AstrBot
|
||||||
|
|
||||||
|
COPY . /AstrBot/
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
gcc \
|
||||||
|
build-essential \
|
||||||
|
python3-dev \
|
||||||
|
libffi-dev \
|
||||||
|
libssl-dev \
|
||||||
|
curl \
|
||||||
|
unzip \
|
||||||
|
ca-certificates \
|
||||||
|
bash \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Installation of Node.js
|
||||||
|
ENV NVM_DIR="/root/.nvm"
|
||||||
|
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
|
||||||
|
. "$NVM_DIR/nvm.sh" && \
|
||||||
|
nvm install 22 && \
|
||||||
|
nvm use 22
|
||||||
|
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
|
||||||
|
|
||||||
|
RUN python -m pip install uv
|
||||||
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
|
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
|
||||||
|
|
||||||
|
EXPOSE 6185
|
||||||
|
EXPOSE 6186
|
||||||
|
|
||||||
|
CMD ["python", "main.py"]
|
||||||
102
README.md
102
README.md
@@ -1,6 +1,6 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -10,13 +10,15 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
|
|
||||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||

|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||
[](https://codecov.io/gh/Soulter/AstrBot)
|

|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
@@ -26,19 +28,31 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
|
|
||||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||||
|
|
||||||
|
[](https://gitcode.com/Soulter/AstrBot)
|
||||||
|
|
||||||
|
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||||
|
-->
|
||||||
|
|
||||||
|
## ✨ 近期更新
|
||||||
|
|
||||||
|
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||||
|
|
||||||
## ✨ 主要功能
|
## ✨ 主要功能
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||||
|
|
||||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||||
>
|
>
|
||||||
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
|
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||||
|
|
||||||
## ✨ 使用方式
|
## ✨ 使用方式
|
||||||
|
|
||||||
@@ -48,22 +62,33 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
|
|
||||||
#### Windows 一键安装器部署
|
#### Windows 一键安装器部署
|
||||||
|
|
||||||
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||||
|
|
||||||
#### Replit 部署
|
#### 宝塔面板部署
|
||||||
|
|
||||||
[](https://repl.it/github/Soulter/AstrBot)
|
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||||
|
|
||||||
#### CasaOS 部署
|
#### CasaOS 部署
|
||||||
|
|
||||||
社区贡献的部署方式。
|
社区贡献的部署方式。
|
||||||
|
|
||||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
||||||
|
|
||||||
#### 手动部署
|
#### 手动部署
|
||||||
|
|
||||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
推荐使用 `uv`。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||||
|
pip install uv
|
||||||
|
uv run main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||||
|
|
||||||
|
#### Replit 部署
|
||||||
|
|
||||||
|
[](https://repl.it/github/Soulter/AstrBot)
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
@@ -74,7 +99,8 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||||
| 飞书 | ✔ | 群聊 | 文字、图片 |
|
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
|
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||||
| Discord | 🚧 | 计划内 | - |
|
| Discord | 🚧 | 计划内 | - |
|
||||||
| WhatsApp | 🚧 | 计划内 | - |
|
| WhatsApp | 🚧 | 计划内 | - |
|
||||||
@@ -84,7 +110,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
|
|
||||||
| 名称 | 支持性 | 类型 | 备注 |
|
| 名称 | 支持性 | 类型 | 备注 |
|
||||||
| -------- | ------- | ------- | ------- |
|
| -------- | ------- | ------- | ------- |
|
||||||
| OpenAI API | ✔ | 文本生成 | 同时也支持 DeepSeek、Google Gemini、GLM(智谱)、Moonshot(月之暗面)、阿里云百炼、硅基流动、xAI 等所有兼容 OpenAI API 的服务 |
|
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||||
| Claude API | ✔ | 文本生成 | |
|
| Claude API | ✔ | 文本生成 | |
|
||||||
| Google Gemini API | ✔ | 文本生成 | |
|
| Google Gemini API | ✔ | 文本生成 | |
|
||||||
| Dify | ✔ | LLMOps | |
|
| Dify | ✔ | LLMOps | |
|
||||||
@@ -96,6 +122,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||||
|
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||||
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||||
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||||
|
|
||||||
@@ -125,38 +152,36 @@ pre-commit install
|
|||||||
|
|
||||||
## ✨ Demo
|
## ✨ Demo
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
|
||||||
|
|
||||||
<div align='center'>
|
<div align='center'>
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||||
|
|
||||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||||
|
|
||||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
|
||||||
|
|
||||||
_✨ 自然语言待办事项 ✨_
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||||
|
|
||||||
_✨ 插件系统——部分插件展示 ✨_
|
_✨ 插件系统——部分插件展示 ✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
||||||
|
|
||||||
_✨ 管理面板 ✨_
|
_✨ WebUI ✨_
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
_✨ 内置 Web Chat,在线与机器人交互 ✨_
|
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
## ❤️ Special Thanks
|
||||||
|
|
||||||
|
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||||
|
|
||||||
|
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||||
|
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||||
|
</a>
|
||||||
|
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
@@ -174,16 +199,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
|||||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||||
3. Please ensure compliance with local laws and regulations when using this project.
|
3. Please ensure compliance with local laws and regulations when using this project.
|
||||||
|
|
||||||
<!-- ## ✨ ATRI [Beta 测试]
|
|
||||||
|
|
||||||
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
|
||||||
|
|
||||||
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
|
|
||||||
2. 长期记忆
|
|
||||||
3. 表情包理解与回复
|
|
||||||
4. TTS
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ AstrBot is a loosely coupled, asynchronous chatbot and development framework tha
|
|||||||
|
|
||||||
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
||||||
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
||||||
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://astrbot.app/others/dify.html) for easy access to Dify assistants/knowledge bases/workflows.
|
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
|
||||||
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
||||||
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
||||||
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
|||||||
|
|
||||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||||
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from astrbot.core.platform import (
|
|||||||
MessageMember,
|
MessageMember,
|
||||||
MessageType,
|
MessageType,
|
||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
|
Group,
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.platform.register import register_platform_adapter
|
from astrbot.core.platform.register import register_platform_adapter
|
||||||
@@ -18,4 +19,5 @@ __all__ = [
|
|||||||
"MessageType",
|
"MessageType",
|
||||||
"PlatformMetadata",
|
"PlatformMetadata",
|
||||||
"register_platform_adapter",
|
"register_platform_adapter",
|
||||||
|
"Group",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||||
from astrbot.core.provider.entites import (
|
from astrbot.core.provider.entities import (
|
||||||
ProviderRequest,
|
ProviderRequest,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
ProviderMetaData,
|
ProviderMetaData,
|
||||||
|
|||||||
@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
|
|||||||
register_star as register, # 注册插件(Star)
|
register_star as register, # 注册插件(Star)
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.star import Context, Star
|
from astrbot.core.star import Context, Star, StarTools
|
||||||
from astrbot.core.star.config import *
|
from astrbot.core.star.config import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["register", "Context", "Star", "StarTools"]
|
||||||
"register",
|
|
||||||
"Context",
|
|
||||||
"Star",
|
|
||||||
]
|
|
||||||
|
|||||||
7
astrbot/api/util/__init__.py
Normal file
7
astrbot/api/util/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from astrbot.core.utils.session_waiter import (
|
||||||
|
SessionWaiter,
|
||||||
|
SessionController,
|
||||||
|
session_waiter,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]
|
||||||
@@ -8,6 +8,7 @@ from astrbot.core.db.sqlite import SQLiteDatabase
|
|||||||
from astrbot.core.config.default import DB_PATH
|
from astrbot.core.config.default import DB_PATH
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
|
|
||||||
|
# 初始化数据存储文件夹
|
||||||
os.makedirs("data", exist_ok=True)
|
os.makedirs("data", exist_ok=True)
|
||||||
|
|
||||||
astrbot_config = AstrBotConfig()
|
astrbot_config = AstrBotConfig()
|
||||||
@@ -19,8 +20,14 @@ if os.environ.get("TESTING", ""):
|
|||||||
logger.setLevel("DEBUG")
|
logger.setLevel("DEBUG")
|
||||||
|
|
||||||
db_helper = SQLiteDatabase(DB_PATH)
|
db_helper = SQLiteDatabase(DB_PATH)
|
||||||
sp = SharedPreferences() # 简单的偏好设置存储
|
sp = (
|
||||||
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
SharedPreferences()
|
||||||
|
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||||
|
pip_installer = PipInstaller(
|
||||||
|
astrbot_config.get("pip_install_arg", ""),
|
||||||
|
astrbot_config.get("pypi_index_url", None),
|
||||||
|
)
|
||||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||||
|
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
VERSION = "3.4.35"
|
VERSION = "3.5.4"
|
||||||
DB_PATH = "data/data_v3.db"
|
DB_PATH = "data/data_v3.db"
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
@@ -36,6 +36,9 @@ DEFAULT_CONFIG = {
|
|||||||
"content_cleanup_rule": "",
|
"content_cleanup_rule": "",
|
||||||
},
|
},
|
||||||
"no_permission_reply": True,
|
"no_permission_reply": True,
|
||||||
|
"empty_mention_waiting": True,
|
||||||
|
"friend_message_needs_wake_prefix": False,
|
||||||
|
"ignore_bot_self_message": False,
|
||||||
},
|
},
|
||||||
"provider": [],
|
"provider": [],
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
@@ -47,6 +50,10 @@ DEFAULT_CONFIG = {
|
|||||||
"datetime_system_prompt": True,
|
"datetime_system_prompt": True,
|
||||||
"default_personality": "default",
|
"default_personality": "default",
|
||||||
"prompt_prefix": "",
|
"prompt_prefix": "",
|
||||||
|
"max_context_length": -1,
|
||||||
|
"dequeue_context_length": 1,
|
||||||
|
"streaming_response": False,
|
||||||
|
"streaming_segmented": False,
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -55,6 +62,7 @@ DEFAULT_CONFIG = {
|
|||||||
"provider_tts_settings": {
|
"provider_tts_settings": {
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"provider_id": "",
|
"provider_id": "",
|
||||||
|
"dual_output": False,
|
||||||
},
|
},
|
||||||
"provider_ltm_settings": {
|
"provider_ltm_settings": {
|
||||||
"group_icl_enable": False,
|
"group_icl_enable": False,
|
||||||
@@ -78,21 +86,24 @@ DEFAULT_CONFIG = {
|
|||||||
"admins_id": ["astrbot"],
|
"admins_id": ["astrbot"],
|
||||||
"t2i": False,
|
"t2i": False,
|
||||||
"t2i_word_threshold": 150,
|
"t2i_word_threshold": 150,
|
||||||
|
"t2i_strategy": "remote",
|
||||||
|
"t2i_endpoint": "",
|
||||||
"http_proxy": "",
|
"http_proxy": "",
|
||||||
"dashboard": {
|
"dashboard": {
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"username": "astrbot",
|
"username": "astrbot",
|
||||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||||
|
"host": "0.0.0.0",
|
||||||
"port": 6185,
|
"port": 6185,
|
||||||
},
|
},
|
||||||
"platform": [],
|
"platform": [],
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
"t2i_endpoint": "",
|
|
||||||
"pip_install_arg": "",
|
"pip_install_arg": "",
|
||||||
"plugin_repo_mirror": "",
|
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
|
||||||
"knowledge_db": {},
|
"knowledge_db": {},
|
||||||
"persona": [],
|
"persona": [],
|
||||||
|
"timezone": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -120,6 +131,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"enable": False,
|
"enable": False,
|
||||||
"appid": "",
|
"appid": "",
|
||||||
"secret": "",
|
"secret": "",
|
||||||
|
"callback_server_host": "0.0.0.0",
|
||||||
"port": 6196,
|
"port": 6196,
|
||||||
},
|
},
|
||||||
"aiocqhttp(OneBotv11)": {
|
"aiocqhttp(OneBotv11)": {
|
||||||
@@ -144,10 +156,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"enable": False,
|
"enable": False,
|
||||||
"corpid": "",
|
"corpid": "",
|
||||||
"secret": "",
|
"secret": "",
|
||||||
"port": 6195,
|
|
||||||
"token": "",
|
"token": "",
|
||||||
"encoding_aes_key": "",
|
"encoding_aes_key": "",
|
||||||
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||||
|
"callback_server_host": "0.0.0.0",
|
||||||
|
"port": 6195,
|
||||||
},
|
},
|
||||||
"lark(飞书)": {
|
"lark(飞书)": {
|
||||||
"id": "lark",
|
"id": "lark",
|
||||||
@@ -158,6 +171,13 @@ CONFIG_METADATA_2 = {
|
|||||||
"app_secret": "",
|
"app_secret": "",
|
||||||
"domain": "https://open.feishu.cn",
|
"domain": "https://open.feishu.cn",
|
||||||
},
|
},
|
||||||
|
"dingtalk(钉钉)": {
|
||||||
|
"id": "dingtalk",
|
||||||
|
"type": "dingtalk",
|
||||||
|
"enable": False,
|
||||||
|
"client_id": "",
|
||||||
|
"client_secret": "",
|
||||||
|
},
|
||||||
"telegram": {
|
"telegram": {
|
||||||
"id": "telegram",
|
"id": "telegram",
|
||||||
"type": "telegram",
|
"type": "telegram",
|
||||||
@@ -165,6 +185,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"telegram_token": "your_bot_token",
|
"telegram_token": "your_bot_token",
|
||||||
"start_message": "Hello, I'm AstrBot!",
|
"start_message": "Hello, I'm AstrBot!",
|
||||||
"telegram_api_base_url": "https://api.telegram.org/bot",
|
"telegram_api_base_url": "https://api.telegram.org/bot",
|
||||||
|
"telegram_file_base_url": "https://api.telegram.org/file/bot",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
@@ -210,7 +231,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"hint": "启用后,机器人可以接收到频道的私聊消息。",
|
"hint": "启用后,机器人可以接收到频道的私聊消息。",
|
||||||
},
|
},
|
||||||
"ws_reverse_host": {
|
"ws_reverse_host": {
|
||||||
"description": "反向 Websocket 主机地址",
|
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
|
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
|
||||||
},
|
},
|
||||||
@@ -231,6 +252,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "平台设置",
|
"description": "平台设置",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
|
"plugin_enable": {
|
||||||
|
"invisible": True, # 隐藏插件启用配置
|
||||||
|
},
|
||||||
"unique_session": {
|
"unique_session": {
|
||||||
"description": "会话隔离",
|
"description": "会话隔离",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -256,6 +280,21 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
||||||
},
|
},
|
||||||
|
"empty_mention_waiting": {
|
||||||
|
"description": "只 @ 机器人是否触发等待回复",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待回复,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
|
||||||
|
},
|
||||||
|
"friend_message_needs_wake_prefix": {
|
||||||
|
"description": "私聊消息是否需要唤醒前缀",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,私聊消息需要唤醒前缀才会被处理,同群聊一样。",
|
||||||
|
},
|
||||||
|
"ignore_bot_self_message": {
|
||||||
|
"description": "是否忽略机器人自身的消息",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||||
|
},
|
||||||
"segmented_reply": {
|
"segmented_reply": {
|
||||||
"description": "分段回复",
|
"description": "分段回复",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -322,7 +361,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"obvious_hint": True,
|
"obvious_hint": True,
|
||||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
|
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
|
||||||
},
|
},
|
||||||
"id_whitelist_log": {
|
"id_whitelist_log": {
|
||||||
"description": "打印白名单日志",
|
"description": "打印白名单日志",
|
||||||
@@ -465,6 +504,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "llama3.1-8b",
|
"model": "llama3.1-8b",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"LM_Studio": {
|
||||||
|
"id": "lm_studio",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": ["lmstudio"],
|
||||||
|
"api_base": "http://localhost:1234/v1",
|
||||||
|
"model_config": {
|
||||||
|
"model": "llama-3.1-8b",
|
||||||
|
},
|
||||||
|
},
|
||||||
"Gemini(OpenAI兼容)": {
|
"Gemini(OpenAI兼容)": {
|
||||||
"id": "gemini_default",
|
"id": "gemini_default",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
@@ -484,7 +533,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://generativelanguage.googleapis.com/",
|
"api_base": "https://generativelanguage.googleapis.com/",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
},
|
||||||
|
"gm_resp_image_modal": False,
|
||||||
|
"gm_native_search": False,
|
||||||
|
"gm_native_coderunner": False,
|
||||||
|
"gm_safety_settings": {
|
||||||
|
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"DeepSeek": {
|
"DeepSeek": {
|
||||||
@@ -548,7 +606,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"dify_api_type": "chat",
|
"dify_api_type": "chat",
|
||||||
"dify_api_key": "",
|
"dify_api_key": "",
|
||||||
"dify_api_base": "https://api.dify.ai/v1",
|
"dify_api_base": "https://api.dify.ai/v1",
|
||||||
"dify_workflow_output_key": "",
|
"dify_workflow_output_key": "astrbot_wf_output",
|
||||||
"dify_query_input_key": "astrbot_text_query",
|
"dify_query_input_key": "astrbot_text_query",
|
||||||
"variables": {},
|
"variables": {},
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
@@ -560,6 +618,11 @@ CONFIG_METADATA_2 = {
|
|||||||
"dashscope_app_type": "agent",
|
"dashscope_app_type": "agent",
|
||||||
"dashscope_api_key": "",
|
"dashscope_api_key": "",
|
||||||
"dashscope_app_id": "",
|
"dashscope_app_id": "",
|
||||||
|
"rag_options": {
|
||||||
|
"pipeline_ids": [],
|
||||||
|
"file_ids": [],
|
||||||
|
"output_reference": False,
|
||||||
|
},
|
||||||
"variables": {},
|
"variables": {},
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
},
|
},
|
||||||
@@ -626,12 +689,118 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "fishaudio_tts_api",
|
"type": "fishaudio_tts_api",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"api_base": "https://api.fish-audio.cn/v1",
|
"api_base": "https://api.fish.audio/v1",
|
||||||
"fishaudio-tts-character": "可莉",
|
"fishaudio-tts-character": "可莉",
|
||||||
"timeout": "20",
|
"timeout": "20",
|
||||||
},
|
},
|
||||||
|
"阿里云百炼_TTS(API)": {
|
||||||
|
"id": "dashscope_tts",
|
||||||
|
"type": "dashscope_tts",
|
||||||
|
"enable": False,
|
||||||
|
"api_key": "",
|
||||||
|
"model": "cosyvoice-v1",
|
||||||
|
"dashscope_tts_voice": "loongstella",
|
||||||
|
"timeout": "20",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
|
"dashscope_tts_voice": {
|
||||||
|
"description": "语音合成模型",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
|
||||||
|
},
|
||||||
|
"gm_resp_image_modal": {
|
||||||
|
"description": "启用图片模态",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
|
||||||
|
},
|
||||||
|
"gm_native_search": {
|
||||||
|
"description": "启用原生搜索功能",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
||||||
|
"obvious_hint": True,
|
||||||
|
},
|
||||||
|
"gm_native_coderunner": {
|
||||||
|
"description": "启用原生代码执行器",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后所有函数工具将全部失效",
|
||||||
|
"obvious_hint": True,
|
||||||
|
},
|
||||||
|
"gm_safety_settings": {
|
||||||
|
"description": "安全过滤器",
|
||||||
|
"type": "object",
|
||||||
|
"hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。",
|
||||||
|
"items": {
|
||||||
|
"harassment": {
|
||||||
|
"description": "骚扰内容",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "负面或有害评论",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"hate_speech": {
|
||||||
|
"description": "仇恨言论",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "粗鲁、无礼或亵渎性质内容",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"sexually_explicit": {
|
||||||
|
"description": "露骨色情内容",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "包含性行为或其他淫秽内容的引用",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"dangerous_content": {
|
||||||
|
"description": "危险内容",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "宣扬、助长或鼓励有害行为的信息",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"rag_options": {
|
||||||
|
"description": "RAG 选项",
|
||||||
|
"type": "object",
|
||||||
|
"hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)。阿里云百炼应用开启此功能后将无法多轮对话。",
|
||||||
|
"items": {
|
||||||
|
"pipeline_ids": {
|
||||||
|
"description": "知识库 ID 列表",
|
||||||
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。",
|
||||||
|
},
|
||||||
|
"file_ids": {
|
||||||
|
"description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。",
|
||||||
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。",
|
||||||
|
},
|
||||||
|
"output_reference": {
|
||||||
|
"description": "是否输出知识库/文档的引用",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "在每次回答尾部加上引用源。默认为 False。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"sensevoice_hint": {
|
"sensevoice_hint": {
|
||||||
"description": "部署SenseVoice",
|
"description": "部署SenseVoice",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -648,12 +817,14 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。",
|
"hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。",
|
||||||
},
|
},
|
||||||
# "variables": {
|
"variables": {
|
||||||
# "description": "工作流固定输入变量",
|
"description": "工作流固定输入变量",
|
||||||
# "type": "object",
|
"type": "object",
|
||||||
# "obvious_hint": True,
|
"obvious_hint": True,
|
||||||
# "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
"items": {},
|
||||||
# },
|
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
||||||
|
"invisible": True,
|
||||||
|
},
|
||||||
# "fastgpt_app_type": {
|
# "fastgpt_app_type": {
|
||||||
# "description": "应用类型",
|
# "description": "应用类型",
|
||||||
# "type": "string",
|
# "type": "string",
|
||||||
@@ -664,7 +835,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"dashscope_app_type": {
|
"dashscope_app_type": {
|
||||||
"description": "应用类型",
|
"description": "应用类型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "阿里云百炼应用的应用类型。",
|
"hint": "百炼应用的应用类型。",
|
||||||
"options": [
|
"options": [
|
||||||
"agent",
|
"agent",
|
||||||
"agent-arrange",
|
"agent-arrange",
|
||||||
@@ -779,8 +950,8 @@ CONFIG_METADATA_2 = {
|
|||||||
"dify_api_type": {
|
"dify_api_type": {
|
||||||
"description": "Dify 应用类型",
|
"description": "Dify 应用类型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, chatflow, agent, workflow 三种应用类型。",
|
||||||
"options": ["chat", "agent", "workflow"],
|
"options": ["chat", "chatflow", "agent", "workflow"],
|
||||||
},
|
},
|
||||||
"dify_workflow_output_key": {
|
"dify_workflow_output_key": {
|
||||||
"description": "Dify Workflow 输出变量名",
|
"description": "Dify Workflow 输出变量名",
|
||||||
@@ -844,6 +1015,26 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
|
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
|
||||||
},
|
},
|
||||||
|
"max_context_length": {
|
||||||
|
"description": "最多携带对话数量(条)",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
|
||||||
|
},
|
||||||
|
"dequeue_context_length": {
|
||||||
|
"description": "丢弃对话数量(条)",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超出 最多携带对话数量(条) 时,丢弃多少条记录,用户和AI的一轮聊天记为 1 条。适宜的配置,可以提高超长上下文对话 deepseek 命中缓存效果,理想情况下计费将降低到1/3以下",
|
||||||
|
},
|
||||||
|
"streaming_response": {
|
||||||
|
"description": "启用流式回复",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||||
|
},
|
||||||
|
"streaming_segmented": {
|
||||||
|
"description": "不支持流式回复的平台分段输出",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"persona": {
|
"persona": {
|
||||||
@@ -918,6 +1109,12 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||||
},
|
},
|
||||||
|
"dual_output": {
|
||||||
|
"description": "启用语音和文字双输出",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||||
|
"obvious_hint": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_ltm_settings": {
|
"provider_ltm_settings": {
|
||||||
@@ -937,10 +1134,10 @@ CONFIG_METADATA_2 = {
|
|||||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||||
},
|
},
|
||||||
"image_caption": {
|
"image_caption": {
|
||||||
"description": "启用图像转述(需要模型支持)",
|
"description": "群聊图像转述(需模型支持)",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"obvious_hint": True,
|
"obvious_hint": True,
|
||||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
|
||||||
},
|
},
|
||||||
"image_caption_provider_id": {
|
"image_caption_provider_id": {
|
||||||
"description": "图像转述提供商 ID",
|
"description": "图像转述提供商 ID",
|
||||||
@@ -1024,32 +1221,38 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
||||||
},
|
},
|
||||||
|
"timezone": {
|
||||||
|
"description": "时区",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||||
|
},
|
||||||
"log_level": {
|
"log_level": {
|
||||||
"description": "控制台日志级别",
|
"description": "控制台日志级别",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "控制台输出日志的级别。",
|
"hint": "控制台输出日志的级别。",
|
||||||
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||||
},
|
},
|
||||||
|
"t2i_strategy": {
|
||||||
|
"description": "文本转图像渲染源",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。",
|
||||||
|
"options": ["remote", "local"],
|
||||||
|
},
|
||||||
"t2i_endpoint": {
|
"t2i_endpoint": {
|
||||||
"description": "文本转图像服务接口",
|
"description": "文本转图像服务接口",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "为空时使用 AstrBot API 服务",
|
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
|
||||||
},
|
},
|
||||||
"pip_install_arg": {
|
"pip_install_arg": {
|
||||||
"description": "pip 安装参数",
|
"description": "pip 安装参数",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。",
|
"hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。",
|
||||||
},
|
},
|
||||||
"plugin_repo_mirror": {
|
"pypi_index_url": {
|
||||||
"description": "插件仓库镜像",
|
"description": "PyPI 软件仓库地址",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
|
"hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/",
|
||||||
"obvious_hint": True,
|
|
||||||
"options": [
|
|
||||||
"default",
|
|
||||||
"https://ghp.ci/",
|
|
||||||
"https://github-mirror.us.kg/",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
"""
|
||||||
|
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
|
||||||
|
|
||||||
|
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
|
||||||
|
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||||
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -11,24 +18,34 @@ class ConversationManager:
|
|||||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||||
|
|
||||||
def __init__(self, db_helper: BaseDatabase):
|
def __init__(self, db_helper: BaseDatabase):
|
||||||
|
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||||
self.db = db_helper
|
self.db = db_helper
|
||||||
self.save_interval = 60 # 每 60 秒保存一次
|
self.save_interval = 60 # 每 60 秒保存一次
|
||||||
self._start_periodic_save()
|
self._start_periodic_save()
|
||||||
|
|
||||||
def _start_periodic_save(self):
|
def _start_periodic_save(self):
|
||||||
|
"""启动定时保存任务"""
|
||||||
asyncio.create_task(self._periodic_save())
|
asyncio.create_task(self._periodic_save())
|
||||||
|
|
||||||
async def _periodic_save(self):
|
async def _periodic_save(self):
|
||||||
|
"""定时保存会话对话映射关系到存储中"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(self.save_interval)
|
await asyncio.sleep(self.save_interval)
|
||||||
self._save_to_storage()
|
self._save_to_storage()
|
||||||
|
|
||||||
def _save_to_storage(self):
|
def _save_to_storage(self):
|
||||||
|
"""保存会话对话映射关系到存储中"""
|
||||||
sp.put("session_conversation", self.session_conversations)
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||||
"""新建对话,并将当前会话的对话转移到新对话"""
|
"""新建对话,并将当前会话的对话转移到新对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
conversation_id = str(uuid.uuid4())
|
conversation_id = str(uuid.uuid4())
|
||||||
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||||
self.session_conversations[unified_msg_origin] = conversation_id
|
self.session_conversations[unified_msg_origin] = conversation_id
|
||||||
@@ -36,14 +53,24 @@ class ConversationManager:
|
|||||||
return conversation_id
|
return conversation_id
|
||||||
|
|
||||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||||
"""切换会话的对话"""
|
"""切换会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
self.session_conversations[unified_msg_origin] = conversation_id
|
self.session_conversations[unified_msg_origin] = conversation_id
|
||||||
sp.put("session_conversation", self.session_conversations)
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
async def delete_conversation(
|
async def delete_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str = None
|
self, unified_msg_origin: str, conversation_id: str = None
|
||||||
):
|
):
|
||||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话"""
|
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||||
@@ -51,23 +78,48 @@ class ConversationManager:
|
|||||||
sp.put("session_conversation", self.session_conversations)
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||||
"""获取会话当前的对话 ID"""
|
"""获取会话当前的对话 ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
return self.session_conversations.get(unified_msg_origin, None)
|
return self.session_conversations.get(unified_msg_origin, None)
|
||||||
|
|
||||||
async def get_conversation(
|
async def get_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str
|
self, unified_msg_origin: str, conversation_id: str
|
||||||
) -> Conversation:
|
) -> Conversation:
|
||||||
"""获取会话的对话"""
|
"""获取会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
Returns:
|
||||||
|
conversation (Conversation): 对话对象
|
||||||
|
"""
|
||||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||||
|
|
||||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||||
"""获取会话的所有对话"""
|
"""获取会话的所有对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversations (List[Conversation]): 对话对象列表
|
||||||
|
"""
|
||||||
return self.db.get_conversations(unified_msg_origin)
|
return self.db.get_conversations(unified_msg_origin)
|
||||||
|
|
||||||
async def update_conversation(
|
async def update_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||||
):
|
):
|
||||||
"""更新会话的对话"""
|
"""更新会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||||
|
"""
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.update_conversation(
|
self.db.update_conversation(
|
||||||
user_id=unified_msg_origin,
|
user_id=unified_msg_origin,
|
||||||
@@ -76,7 +128,12 @@ class ConversationManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||||
"""更新会话的对话标题"""
|
"""更新会话的对话标题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
title (str): 对话标题
|
||||||
|
"""
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.update_conversation_title(
|
self.db.update_conversation_title(
|
||||||
@@ -86,7 +143,12 @@ class ConversationManager:
|
|||||||
async def update_conversation_persona_id(
|
async def update_conversation_persona_id(
|
||||||
self, unified_msg_origin: str, persona_id: str
|
self, unified_msg_origin: str, persona_id: str
|
||||||
):
|
):
|
||||||
"""更新会话的对话 Persona ID"""
|
"""更新会话的对话 Persona ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
persona_id (str): 对话 Persona ID
|
||||||
|
"""
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.update_conversation_persona_id(
|
self.db.update_conversation_persona_id(
|
||||||
@@ -96,6 +158,14 @@ class ConversationManager:
|
|||||||
async def get_human_readable_context(
|
async def get_human_readable_context(
|
||||||
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||||
):
|
):
|
||||||
|
"""获取人类可读的上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
page (int): 页码
|
||||||
|
page_size (int): 每页大小
|
||||||
|
"""
|
||||||
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||||
history = json.loads(conversation.history)
|
history = json.loads(conversation.history)
|
||||||
|
|
||||||
@@ -105,7 +175,15 @@ class ConversationManager:
|
|||||||
if record["role"] == "user":
|
if record["role"] == "user":
|
||||||
temp_contexts.append(f"User: {record['content']}")
|
temp_contexts.append(f"User: {record['content']}")
|
||||||
elif record["role"] == "assistant":
|
elif record["role"] == "assistant":
|
||||||
temp_contexts.append(f"Assistant: {record['content']}")
|
if "content" in record and record["content"]:
|
||||||
|
temp_contexts.append(f"Assistant: {record['content']}")
|
||||||
|
elif "tool_calls" in record:
|
||||||
|
tool_calls_str = json.dumps(
|
||||||
|
record["tool_calls"], ensure_ascii=False
|
||||||
|
)
|
||||||
|
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
|
||||||
|
else:
|
||||||
|
temp_contexts.append("Assistant: [未知的内容]")
|
||||||
contexts.insert(0, temp_contexts)
|
contexts.insert(0, temp_contexts)
|
||||||
temp_contexts = []
|
temp_contexts = []
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||||
|
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||||
|
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 初始化所有组件
|
||||||
|
2. 启动事件总线和任务, 所有任务都在这里运行
|
||||||
|
3. 执行启动完成事件钩子
|
||||||
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -24,32 +35,51 @@ from astrbot.core.star.star_handler import star_map
|
|||||||
|
|
||||||
|
|
||||||
class AstrBotCoreLifecycle:
|
class AstrBotCoreLifecycle:
|
||||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
"""
|
||||||
self.log_broker = log_broker
|
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||||
self.astrbot_config = astrbot_config
|
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||||
self.db = db
|
EventBus 等。
|
||||||
|
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||||
|
self.log_broker = log_broker # 初始化日志代理
|
||||||
|
self.astrbot_config = astrbot_config # 初始化配置
|
||||||
|
self.db = db # 初始化数据库
|
||||||
|
|
||||||
|
# 根据环境变量设置代理
|
||||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
os.environ["no_proxy"] = "localhost"
|
os.environ["no_proxy"] = "localhost"
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
"""
|
||||||
|
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 初始化日志代理
|
||||||
logger.info("AstrBot v" + VERSION)
|
logger.info("AstrBot v" + VERSION)
|
||||||
if os.environ.get("TESTING", ""):
|
if os.environ.get("TESTING", ""):
|
||||||
logger.setLevel("DEBUG")
|
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
|
||||||
else:
|
else:
|
||||||
logger.setLevel(self.astrbot_config["log_level"])
|
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||||
self.event_queue = Queue()
|
|
||||||
self.event_queue.closed = False
|
|
||||||
|
|
||||||
|
# 初始化事件队列
|
||||||
|
self.event_queue = Queue()
|
||||||
|
|
||||||
|
# 初始化供应商管理器
|
||||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||||
|
|
||||||
|
# 初始化平台管理器
|
||||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||||
|
|
||||||
|
# 初始化知识库管理器
|
||||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||||
|
|
||||||
|
# 初始化对话管理器
|
||||||
self.conversation_manager = ConversationManager(self.db)
|
self.conversation_manager = ConversationManager(self.db)
|
||||||
|
|
||||||
|
# 初始化提供给插件的上下文
|
||||||
self.star_context = Context(
|
self.star_context = Context(
|
||||||
self.event_queue,
|
self.event_queue,
|
||||||
self.astrbot_config,
|
self.astrbot_config,
|
||||||
@@ -59,33 +89,50 @@ class AstrBotCoreLifecycle:
|
|||||||
self.conversation_manager,
|
self.conversation_manager,
|
||||||
self.knowledge_db_manager,
|
self.knowledge_db_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 初始化插件管理器
|
||||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||||
|
|
||||||
|
# 扫描、注册插件、实例化插件类
|
||||||
await self.plugin_manager.reload()
|
await self.plugin_manager.reload()
|
||||||
"""扫描、注册插件、实例化插件类"""
|
|
||||||
|
|
||||||
|
# 根据配置实例化各个 Provider
|
||||||
await self.provider_manager.initialize()
|
await self.provider_manager.initialize()
|
||||||
"""根据配置实例化各个 Provider"""
|
|
||||||
|
|
||||||
|
# 初始化消息事件流水线调度器
|
||||||
self.pipeline_scheduler = PipelineScheduler(
|
self.pipeline_scheduler = PipelineScheduler(
|
||||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||||
)
|
)
|
||||||
await self.pipeline_scheduler.initialize()
|
await self.pipeline_scheduler.initialize()
|
||||||
"""初始化消息事件流水线调度器"""
|
|
||||||
|
|
||||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
# 初始化更新器
|
||||||
|
self.astrbot_updator = AstrBotUpdator()
|
||||||
|
|
||||||
|
# 初始化事件总线
|
||||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||||
|
|
||||||
|
# 记录启动时间
|
||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
|
# 初始化当前任务列表
|
||||||
self.curr_tasks: List[asyncio.Task] = []
|
self.curr_tasks: List[asyncio.Task] = []
|
||||||
|
|
||||||
|
# 根据配置实例化各个平台适配器
|
||||||
await self.platform_manager.initialize()
|
await self.platform_manager.initialize()
|
||||||
"""根据配置实例化各个平台适配器"""
|
|
||||||
|
# 初始化关闭控制面板的事件
|
||||||
|
self.dashboard_shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
def _load(self):
|
def _load(self):
|
||||||
|
"""加载事件总线和任务并初始化"""
|
||||||
|
|
||||||
|
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||||
|
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
|
||||||
event_bus_task = asyncio.create_task(
|
event_bus_task = asyncio.create_task(
|
||||||
self.event_bus.dispatch(), name="event_bus"
|
self.event_bus.dispatch(), name="event_bus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||||
extra_tasks = []
|
extra_tasks = []
|
||||||
for task in self.star_context._register_tasks:
|
for task in self.star_context._register_tasks:
|
||||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||||
@@ -99,17 +146,24 @@ class AstrBotCoreLifecycle:
|
|||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
async def _task_wrapper(self, task: asyncio.Task):
|
async def _task_wrapper(self, task: asyncio.Task):
|
||||||
|
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (asyncio.Task): 要执行的异步任务
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass # 任务被取消, 静默处理
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
|
||||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||||
for line in traceback.format_exc().split("\n"):
|
for line in traceback.format_exc().split("\n"):
|
||||||
logger.error(f"| {line}")
|
logger.error(f"| {line}")
|
||||||
logger.error("-------")
|
logger.error("-------")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
|
||||||
self._load()
|
self._load()
|
||||||
logger.info("AstrBot 启动完成。")
|
logger.info("AstrBot 启动完成。")
|
||||||
|
|
||||||
@@ -126,15 +180,29 @@ class AstrBotCoreLifecycle:
|
|||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# 同时运行curr_tasks中的所有任务
|
||||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.event_queue.closed = True
|
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
|
||||||
|
# 请求停止所有正在运行的异步任务
|
||||||
for task in self.curr_tasks:
|
for task in self.curr_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
await self.provider_manager.terminate()
|
for plugin in self.plugin_manager.context.get_all_stars():
|
||||||
|
try:
|
||||||
|
await self.plugin_manager._terminate_plugin(plugin)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(traceback.format_exc())
|
||||||
|
logger.warning(
|
||||||
|
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.provider_manager.terminate()
|
||||||
|
await self.platform_manager.terminate()
|
||||||
|
self.dashboard_shutdown_event.set()
|
||||||
|
|
||||||
|
# 再次遍历curr_tasks等待每个任务真正结束
|
||||||
for task in self.curr_tasks:
|
for task in self.curr_tasks:
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
@@ -143,13 +211,17 @@ class AstrBotCoreLifecycle:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||||
|
|
||||||
def restart(self):
|
async def restart(self):
|
||||||
self.event_queue.closed = True
|
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||||
|
await self.provider_manager.terminate()
|
||||||
|
await self.platform_manager.terminate()
|
||||||
|
self.dashboard_shutdown_event.set()
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
def load_platform(self) -> List[asyncio.Task]:
|
def load_platform(self) -> List[asyncio.Task]:
|
||||||
|
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||||
tasks = []
|
tasks = []
|
||||||
platform_insts = self.platform_manager.get_insts()
|
platform_insts = self.platform_manager.get_insts()
|
||||||
for platform_inst in platform_insts:
|
for platform_inst in platform_insts:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List, Dict, Any, Tuple
|
||||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||||
|
|
||||||
|
|
||||||
@@ -117,3 +117,45 @@ class BaseDatabase(abc.ABC):
|
|||||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||||
"""更新 Conversation Persona ID"""
|
"""更新 Conversation Persona ID"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_all_conversations(
|
||||||
|
self, page: int = 1, page_size: int = 20
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取所有对话,支持分页
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码,从1开始
|
||||||
|
page_size: 每页数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
platforms: List[str] = None,
|
||||||
|
message_types: List[str] = None,
|
||||||
|
search_query: str = None,
|
||||||
|
exclude_ids: List[str] = None,
|
||||||
|
exclude_platforms: List[str] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取筛选后的对话列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码
|
||||||
|
page_size: 每页数量
|
||||||
|
platforms: 平台筛选列表
|
||||||
|
message_types: 消息类型筛选列表
|
||||||
|
search_query: 搜索关键词
|
||||||
|
exclude_ids: 排除的用户ID列表
|
||||||
|
exclude_platforms: 排除的平台列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
112
astrbot/core/db/plugin/sqlite_impl.py
Normal file
112
astrbot/core/db/plugin/sqlite_impl.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import json
|
||||||
|
import aiosqlite
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from .plugin_storage import PluginStorage
|
||||||
|
|
||||||
|
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||||
|
|
||||||
|
|
||||||
|
class SQLitePluginStorage(PluginStorage):
|
||||||
|
"""插件数据的 SQLite 存储实现类。
|
||||||
|
|
||||||
|
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||||
|
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None # Standalone instance of the class
|
||||||
|
_db_conn = None
|
||||||
|
db_path = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""
|
||||||
|
创建或获取 SQLitePluginStorage 的单例实例。
|
||||||
|
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||||
|
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||||
|
"""
|
||||||
|
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||||
|
cls._instance.db_path = DBPATH
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
async def _init_db(self):
|
||||||
|
"""初始化数据库连接(只执行一次)"""
|
||||||
|
if SQLitePluginStorage._db_conn is None:
|
||||||
|
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||||
|
await self._setup_db()
|
||||||
|
|
||||||
|
async def _setup_db(self):
|
||||||
|
"""
|
||||||
|
异步初始化数据库。
|
||||||
|
|
||||||
|
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||||
|
其中 plugin 和 key 组合作为主键。
|
||||||
|
"""
|
||||||
|
await self._db_conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||||
|
plugin TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
PRIMARY KEY (plugin, key)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
await self._db_conn.commit()
|
||||||
|
|
||||||
|
async def set(self, plugin: str, key: str, value: Any):
|
||||||
|
"""
|
||||||
|
异步存储数据。
|
||||||
|
|
||||||
|
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||||
|
值会被序列化为 JSON 字符串后存储。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin: 插件标识符
|
||||||
|
key: 数据键名
|
||||||
|
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||||
|
"""
|
||||||
|
await self._init_db()
|
||||||
|
await self._db_conn.execute(
|
||||||
|
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||||
|
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||||
|
(plugin, key, json.dumps(value)),
|
||||||
|
)
|
||||||
|
await self._db_conn.commit()
|
||||||
|
|
||||||
|
async def get(self, plugin: str, key: str) -> Any:
|
||||||
|
"""
|
||||||
|
异步获取数据。
|
||||||
|
|
||||||
|
从数据库中获取指定插件和键名对应的值,
|
||||||
|
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin: 插件标识符
|
||||||
|
key: 数据键名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 存储的数据值,如果未找到则返回 None
|
||||||
|
"""
|
||||||
|
await self._init_db()
|
||||||
|
async with self._db_conn.execute(
|
||||||
|
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||||
|
(plugin, key),
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return json.loads(row[0]) if row else None
|
||||||
|
|
||||||
|
async def delete(self, plugin: str, key: str):
|
||||||
|
"""
|
||||||
|
异步删除数据。
|
||||||
|
|
||||||
|
从数据库中删除指定插件和键名对应的数据项。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin: 插件标识符
|
||||||
|
key: 要删除的数据键名
|
||||||
|
"""
|
||||||
|
await self._init_db()
|
||||||
|
await self._db_conn.execute(
|
||||||
|
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||||
|
)
|
||||||
|
await self._db_conn.commit()
|
||||||
@@ -6,6 +6,8 @@ from typing import List
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Platform:
|
class Platform:
|
||||||
|
"""平台使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
@@ -13,6 +15,8 @@ class Platform:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Provider:
|
class Provider:
|
||||||
|
"""供应商使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
@@ -20,6 +24,8 @@ class Provider:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Plugin:
|
class Plugin:
|
||||||
|
"""插件使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
@@ -27,6 +33,8 @@ class Plugin:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Command:
|
class Command:
|
||||||
|
"""命令使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
|
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
|
||||||
from . import BaseDatabase
|
from . import BaseDatabase
|
||||||
from typing import Tuple
|
from typing import Tuple, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase(BaseDatabase):
|
class SQLiteDatabase(BaseDatabase):
|
||||||
@@ -128,24 +128,23 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
where_clause = ""
|
conditions = []
|
||||||
if session_id or provider_type:
|
params = []
|
||||||
where_clause += " WHERE "
|
|
||||||
has = False
|
if session_id:
|
||||||
if session_id:
|
conditions.append("session_id = ?")
|
||||||
where_clause += f"session_id = '{session_id}'"
|
params.append(session_id)
|
||||||
has = True
|
|
||||||
if provider_type:
|
if provider_type:
|
||||||
if has:
|
conditions.append("provider_type = ?")
|
||||||
where_clause += " AND "
|
params.append(provider_type)
|
||||||
where_clause += f"provider_type = '{provider_type}'"
|
|
||||||
|
sql = "SELECT * FROM llm_history"
|
||||||
|
if conditions:
|
||||||
|
sql += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
|
c.execute(sql, params)
|
||||||
|
|
||||||
c.execute(
|
|
||||||
"""
|
|
||||||
SELECT * FROM llm_history
|
|
||||||
"""
|
|
||||||
+ where_clause
|
|
||||||
)
|
|
||||||
res = c.fetchall()
|
res = c.fetchall()
|
||||||
histories = []
|
histories = []
|
||||||
for row in res:
|
for row in res:
|
||||||
@@ -389,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
if res:
|
if res:
|
||||||
return ATRIVision(*res)
|
return ATRIVision(*res)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_all_conversations(
|
||||||
|
self, page: int = 1, page_size: int = 20
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取所有对话,支持分页,按更新时间降序排序"""
|
||||||
|
try:
|
||||||
|
c = self.conn.cursor()
|
||||||
|
except sqlite3.ProgrammingError:
|
||||||
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取总记录数
|
||||||
|
c.execute("""
|
||||||
|
SELECT COUNT(*) FROM webchat_conversation
|
||||||
|
""")
|
||||||
|
total_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
# 计算偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 获取分页数据,按更新时间降序排序
|
||||||
|
c.execute(
|
||||||
|
"""
|
||||||
|
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||||
|
FROM webchat_conversation
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
""",
|
||||||
|
(page_size, offset),
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = c.fetchall()
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||||
|
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
|
||||||
|
safe_cid = str(cid) if cid else "unknown"
|
||||||
|
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||||
|
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"cid": safe_cid,
|
||||||
|
"title": title or f"对话 {display_cid}",
|
||||||
|
"persona_id": persona_id or "",
|
||||||
|
"created_at": created_at or 0,
|
||||||
|
"updated_at": updated_at or 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total_count
|
||||||
|
|
||||||
|
except Exception as _:
|
||||||
|
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||||
|
return [], 0
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
platforms: List[str] = None,
|
||||||
|
message_types: List[str] = None,
|
||||||
|
search_query: str = None,
|
||||||
|
exclude_ids: List[str] = None,
|
||||||
|
exclude_platforms: List[str] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取筛选后的对话列表"""
|
||||||
|
try:
|
||||||
|
c = self.conn.cursor()
|
||||||
|
except sqlite3.ProgrammingError:
|
||||||
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建查询条件
|
||||||
|
where_clauses = []
|
||||||
|
params = []
|
||||||
|
|
||||||
|
# 平台筛选
|
||||||
|
if platforms and len(platforms) > 0:
|
||||||
|
platform_conditions = []
|
||||||
|
for platform in platforms:
|
||||||
|
platform_conditions.append("user_id LIKE ?")
|
||||||
|
params.append(f"{platform}:%")
|
||||||
|
|
||||||
|
if platform_conditions:
|
||||||
|
where_clauses.append(f"({' OR '.join(platform_conditions)})")
|
||||||
|
|
||||||
|
# 消息类型筛选
|
||||||
|
if message_types and len(message_types) > 0:
|
||||||
|
message_type_conditions = []
|
||||||
|
for msg_type in message_types:
|
||||||
|
message_type_conditions.append("user_id LIKE ?")
|
||||||
|
params.append(f"%:{msg_type}:%")
|
||||||
|
|
||||||
|
if message_type_conditions:
|
||||||
|
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
||||||
|
|
||||||
|
# 搜索关键词
|
||||||
|
if search_query:
|
||||||
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||||
|
where_clauses.append(
|
||||||
|
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
||||||
|
)
|
||||||
|
search_param = f"%{search_query}%"
|
||||||
|
params.extend([search_param, search_param, search_param, search_param])
|
||||||
|
|
||||||
|
# 排除特定用户ID
|
||||||
|
if exclude_ids and len(exclude_ids) > 0:
|
||||||
|
for exclude_id in exclude_ids:
|
||||||
|
where_clauses.append("user_id NOT LIKE ?")
|
||||||
|
params.append(f"{exclude_id}%")
|
||||||
|
|
||||||
|
# 排除特定平台
|
||||||
|
if exclude_platforms and len(exclude_platforms) > 0:
|
||||||
|
for exclude_platform in exclude_platforms:
|
||||||
|
where_clauses.append("user_id NOT LIKE ?")
|
||||||
|
params.append(f"{exclude_platform}:%")
|
||||||
|
|
||||||
|
# 构建完整的 WHERE 子句
|
||||||
|
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||||
|
|
||||||
|
# 构建计数查询
|
||||||
|
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
|
||||||
|
|
||||||
|
# 获取总记录数
|
||||||
|
c.execute(count_sql, params)
|
||||||
|
total_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
# 计算偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 构建分页数据查询
|
||||||
|
data_sql = f"""
|
||||||
|
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||||
|
FROM webchat_conversation
|
||||||
|
{where_sql}
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
query_params = params + [page_size, offset]
|
||||||
|
|
||||||
|
# 获取分页数据
|
||||||
|
c.execute(data_sql, query_params)
|
||||||
|
rows = c.fetchall()
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||||
|
# 确保 cid 是字符串类型,否则使用一个默认值
|
||||||
|
safe_cid = str(cid) if cid else "unknown"
|
||||||
|
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||||
|
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"cid": safe_cid,
|
||||||
|
"title": title or f"对话 {display_cid}",
|
||||||
|
"persona_id": persona_id or "",
|
||||||
|
"created_at": created_at or 0,
|
||||||
|
"updated_at": updated_at or 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total_count
|
||||||
|
|
||||||
|
except Exception as _:
|
||||||
|
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||||
|
return [], 0
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
|||||||
@@ -38,11 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||||
user_id TEXT,
|
user_id TEXT, -- 会话 id
|
||||||
cid TEXT,
|
cid TEXT, -- 对话 id
|
||||||
history TEXT,
|
history TEXT,
|
||||||
created_at INTEGER,
|
created_at INTEGER,
|
||||||
updated_at INTEGER,
|
updated_at INTEGER,
|
||||||
title TEXT,
|
title TEXT,
|
||||||
persona_id TEXT
|
persona_id TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
|
PRAGMA encoding = 'UTF-8';
|
||||||
@@ -1,3 +1,16 @@
|
|||||||
|
"""
|
||||||
|
事件总线, 用于处理事件的分发和处理
|
||||||
|
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||||
|
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
|
class:
|
||||||
|
EventBus: 事件总线, 用于处理事件的分发和处理
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 维护一个异步队列, 来接受各种消息事件
|
||||||
|
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||||
@@ -6,21 +19,38 @@ from .platform import AstrMessageEvent
|
|||||||
|
|
||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
|
"""事件总线: 用于处理事件的分发和处理
|
||||||
|
|
||||||
|
维护一个异步队列, 来接受各种消息事件
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue # 事件队列
|
||||||
self.pipeline_scheduler = pipeline_scheduler
|
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||||
|
|
||||||
async def dispatch(self):
|
async def dispatch(self):
|
||||||
|
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||||
while True:
|
while True:
|
||||||
event: AstrMessageEvent = await self.event_queue.get()
|
event: AstrMessageEvent = (
|
||||||
self._print_event(event)
|
await self.event_queue.get()
|
||||||
asyncio.create_task(self.pipeline_scheduler.execute(event))
|
) # 从事件队列中获取新的事件
|
||||||
|
self._print_event(event) # 打印日志
|
||||||
|
asyncio.create_task(
|
||||||
|
self.pipeline_scheduler.execute(event)
|
||||||
|
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
def _print_event(self, event: AstrMessageEvent):
|
def _print_event(self, event: AstrMessageEvent):
|
||||||
|
"""用于记录事件信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
"""
|
||||||
|
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||||
if event.get_sender_name():
|
if event.get_sender_name():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||||
)
|
)
|
||||||
|
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||||
|
|||||||
@@ -1,18 +1,27 @@
|
|||||||
|
"""
|
||||||
|
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
|
||||||
|
2. 运行核心生命周期任务和仪表板服务器
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from .server import AstrBotDashboard
|
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core import LogBroker
|
from astrbot.core import LogBroker
|
||||||
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
|
|
||||||
|
|
||||||
class AstrBotDashBoardLifecycle:
|
class InitialLoader:
|
||||||
|
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
|
||||||
|
|
||||||
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
self.dashboard_server = None
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||||
@@ -25,11 +34,15 @@ class AstrBotDashBoardLifecycle:
|
|||||||
logger.critical(traceback.format_exc())
|
logger.critical(traceback.format_exc())
|
||||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||||
|
|
||||||
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
|
self.dashboard_server = AstrBotDashboard(
|
||||||
task = asyncio.gather(core_task, self.dashboard_server.run())
|
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||||
|
)
|
||||||
|
task = asyncio.gather(
|
||||||
|
core_task, self.dashboard_server.run()
|
||||||
|
) # 启动核心任务和仪表板服务器
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await task
|
await task # 整个AstrBot在这里运行
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("🌈 正在关闭 AstrBot...")
|
logger.info("🌈 正在关闭 AstrBot...")
|
||||||
await core_lifecycle.stop()
|
await core_lifecycle.stop()
|
||||||
@@ -1,11 +1,38 @@
|
|||||||
|
"""
|
||||||
|
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
|
||||||
|
|
||||||
|
const:
|
||||||
|
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
|
||||||
|
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
|
||||||
|
|
||||||
|
class:
|
||||||
|
LogBroker: 日志代理类, 用于缓存和分发日志消息
|
||||||
|
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
LogManager: 日志管理器, 用于创建和配置日志记录器
|
||||||
|
|
||||||
|
function:
|
||||||
|
is_plugin_path: 检查文件路径是否来自插件目录
|
||||||
|
get_short_level_name: 将日志级别名称转换为四个字母的缩写
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
|
||||||
|
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
|
||||||
|
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
|
||||||
|
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import colorlog
|
import colorlog
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
# 日志缓存大小
|
||||||
CACHED_SIZE = 200
|
CACHED_SIZE = 200
|
||||||
|
# 日志颜色配置
|
||||||
log_color_config = {
|
log_color_config = {
|
||||||
"DEBUG": "green",
|
"DEBUG": "green",
|
||||||
"INFO": "bold_cyan",
|
"INFO": "bold_cyan",
|
||||||
@@ -17,13 +44,57 @@ log_color_config = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_plugin_path(pathname):
|
||||||
|
"""检查文件路径是否来自插件目录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pathname (str): 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果路径来自插件目录,则返回 True,否则返回 False
|
||||||
|
"""
|
||||||
|
if not pathname:
|
||||||
|
return False
|
||||||
|
|
||||||
|
norm_path = os.path.normpath(pathname)
|
||||||
|
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_short_level_name(level_name):
|
||||||
|
"""将日志级别名称转换为四个字母的缩写
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 四个字母的日志级别缩写
|
||||||
|
"""
|
||||||
|
level_map = {
|
||||||
|
"DEBUG": "DBUG",
|
||||||
|
"INFO": "INFO",
|
||||||
|
"WARNING": "WARN",
|
||||||
|
"ERROR": "ERRO",
|
||||||
|
"CRITICAL": "CRIT",
|
||||||
|
}
|
||||||
|
return level_map.get(level_name, level_name[:4].upper())
|
||||||
|
|
||||||
|
|
||||||
class LogBroker:
|
class LogBroker:
|
||||||
|
"""日志代理类, 用于缓存和分发日志消息
|
||||||
|
|
||||||
|
发布-订阅模式
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_cache = deque(maxlen=CACHED_SIZE)
|
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
|
||||||
self.subscribers: List[Queue] = []
|
self.subscribers: List[Queue] = [] # 订阅者列表
|
||||||
|
|
||||||
def register(self) -> Queue:
|
def register(self) -> Queue:
|
||||||
"""给每个订阅者返回一个带有日志缓存的队列"""
|
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Queue: 订阅者的队列, 可用于接收日志消息
|
||||||
|
"""
|
||||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||||
for log in self.log_cache:
|
for log in self.log_cache:
|
||||||
q.put_nowait(log)
|
q.put_nowait(log)
|
||||||
@@ -31,11 +102,20 @@ class LogBroker:
|
|||||||
return q
|
return q
|
||||||
|
|
||||||
def unregister(self, q: Queue):
|
def unregister(self, q: Queue):
|
||||||
"""取消订阅"""
|
"""取消订阅
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (Queue): 需要取消订阅的队列
|
||||||
|
"""
|
||||||
self.subscribers.remove(q)
|
self.subscribers.remove(q)
|
||||||
|
|
||||||
def publish(self, log_entry: str):
|
def publish(self, log_entry: dict):
|
||||||
"""发布消息"""
|
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_entry (dict): 日志消息, 包含日志级别和日志内容.
|
||||||
|
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
|
||||||
|
"""
|
||||||
self.log_cache.append(log_entry)
|
self.log_cache.append(log_entry)
|
||||||
for q in self.subscribers:
|
for q in self.subscribers:
|
||||||
try:
|
try:
|
||||||
@@ -45,44 +125,124 @@ class LogBroker:
|
|||||||
|
|
||||||
|
|
||||||
class LogQueueHandler(logging.Handler):
|
class LogQueueHandler(logging.Handler):
|
||||||
|
"""日志处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
|
||||||
|
继承自 logging.Handler
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, log_broker: LogBroker):
|
def __init__(self, log_broker: LogBroker):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
|
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
|
||||||
|
这个方法会在每次日志记录时被调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record (logging.LogRecord): 日志记录对象, 包含日志信息
|
||||||
|
"""
|
||||||
log_entry = self.format(record)
|
log_entry = self.format(record)
|
||||||
self.log_broker.publish(log_entry)
|
self.log_broker.publish(
|
||||||
|
{
|
||||||
|
"level": record.levelname,
|
||||||
|
"time": record.asctime,
|
||||||
|
"data": log_entry,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogManager:
|
class LogManager:
|
||||||
|
"""日志管理器, 用于创建和配置日志记录器
|
||||||
|
|
||||||
|
提供了获取默认日志记录器logger和设置队列处理器的方法
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def GetLogger(cls, log_name: str = "default"):
|
def GetLogger(cls, log_name: str = "default"):
|
||||||
|
"""获取指定名称的日志记录器logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_name (str): 日志记录器的名称, 默认为 "default"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 返回配置好的日志记录器
|
||||||
|
"""
|
||||||
logger = logging.getLogger(log_name)
|
logger = logging.getLogger(log_name)
|
||||||
|
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
|
||||||
if logger.hasHandlers():
|
if logger.hasHandlers():
|
||||||
return logger
|
return logger
|
||||||
console_handler = logging.StreamHandler()
|
# 如果logger没有处理器
|
||||||
console_handler.setLevel(logging.DEBUG)
|
console_handler = logging.StreamHandler(
|
||||||
|
sys.stdout
|
||||||
|
) # 创建一个StreamHandler用于控制台输出
|
||||||
|
console_handler.setLevel(
|
||||||
|
logging.DEBUG
|
||||||
|
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||||
|
|
||||||
|
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
|
||||||
console_formatter = colorlog.ColoredFormatter(
|
console_formatter = colorlog.ColoredFormatter(
|
||||||
fmt="%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
|
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
|
||||||
datefmt="%H:%M:%S",
|
datefmt="%H:%M:%S",
|
||||||
log_colors=log_color_config,
|
log_colors=log_color_config,
|
||||||
)
|
)
|
||||||
console_handler.setFormatter(console_formatter)
|
|
||||||
logger.setLevel(logging.DEBUG)
|
class PluginFilter(logging.Filter):
|
||||||
logger.addHandler(console_handler)
|
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
record.plugin_tag = (
|
||||||
|
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
class FileNameFilter(logging.Filter):
|
||||||
|
"""文件名过滤器类, 用于修改日志记录的文件名格式
|
||||||
|
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
|
||||||
|
|
||||||
|
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
|
||||||
|
def filter(self, record):
|
||||||
|
dirname = os.path.dirname(record.pathname)
|
||||||
|
record.filename = (
|
||||||
|
os.path.basename(dirname)
|
||||||
|
+ "."
|
||||||
|
+ os.path.basename(record.pathname).replace(".py", "")
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
class LevelNameFilter(logging.Filter):
|
||||||
|
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
|
||||||
|
|
||||||
|
# 添加短日志级别名称
|
||||||
|
def filter(self, record):
|
||||||
|
record.short_levelname = get_short_level_name(record.levelname)
|
||||||
|
return True
|
||||||
|
|
||||||
|
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
|
||||||
|
logger.addFilter(PluginFilter()) # 添加插件过滤器
|
||||||
|
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
|
||||||
|
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
|
||||||
|
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
|
||||||
|
logger.addHandler(console_handler) # 添加处理器到logger
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
||||||
|
"""设置队列处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger (logging.Logger): 日志记录器
|
||||||
|
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
|
||||||
|
"""
|
||||||
handler = LogQueueHandler(log_broker)
|
handler = LogQueueHandler(log_broker)
|
||||||
handler.setLevel(logging.DEBUG)
|
handler.setLevel(logging.DEBUG)
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
handler.setFormatter(logger.handlers[0].formatter)
|
handler.setFormatter(logger.handlers[0].formatter)
|
||||||
else:
|
else:
|
||||||
|
# 为队列处理器设置相同格式的formatter
|
||||||
handler.setFormatter(
|
handler.setFormatter(
|
||||||
logging.Formatter(
|
logging.Formatter(
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|||||||
@@ -25,9 +25,11 @@ SOFTWARE.
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
import typing as T
|
import typing as T
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic.v1 import BaseModel
|
from pydantic.v1 import BaseModel
|
||||||
|
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||||
|
|
||||||
|
|
||||||
class ComponentType(Enum):
|
class ComponentType(Enum):
|
||||||
@@ -59,6 +61,8 @@ class ComponentType(Enum):
|
|||||||
TTS = "TTS"
|
TTS = "TTS"
|
||||||
Unknown = "Unknown"
|
Unknown = "Unknown"
|
||||||
|
|
||||||
|
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageComponent(BaseModel):
|
class BaseMessageComponent(BaseModel):
|
||||||
type: ComponentType
|
type: ComponentType
|
||||||
@@ -146,6 +150,52 @@ class Record(BaseMessageComponent):
|
|||||||
return Record(file=url, **_)
|
return Record(file=url, **_)
|
||||||
raise Exception("not a valid url")
|
raise Exception("not a valid url")
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 语音的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
if self.file and self.file.startswith("file:///"):
|
||||||
|
file_path = self.file[8:]
|
||||||
|
return file_path
|
||||||
|
elif self.file and self.file.startswith("http"):
|
||||||
|
file_path = await download_image_by_url(self.file)
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
elif self.file and self.file.startswith("base64://"):
|
||||||
|
bs64_data = self.file.removeprefix("base64://")
|
||||||
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
|
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
elif os.path.exists(self.file):
|
||||||
|
file_path = self.file
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
|
||||||
|
async def convert_to_base64(self) -> str:
|
||||||
|
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
|
"""
|
||||||
|
# convert to base64
|
||||||
|
if self.file and self.file.startswith("file:///"):
|
||||||
|
bs64_data = file_to_base64(self.file[8:])
|
||||||
|
elif self.file and self.file.startswith("http"):
|
||||||
|
file_path = await download_image_by_url(self.file)
|
||||||
|
bs64_data = file_to_base64(file_path)
|
||||||
|
elif self.file and self.file.startswith("base64://"):
|
||||||
|
bs64_data = self.file
|
||||||
|
elif os.path.exists(self.file):
|
||||||
|
bs64_data = file_to_base64(self.file)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
bs64_data = bs64_data.removeprefix("base64://")
|
||||||
|
return bs64_data
|
||||||
|
|
||||||
|
|
||||||
class Video(BaseMessageComponent):
|
class Video(BaseMessageComponent):
|
||||||
type: ComponentType = "Video"
|
type: ComponentType = "Video"
|
||||||
@@ -279,10 +329,6 @@ class Image(BaseMessageComponent):
|
|||||||
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
||||||
|
|
||||||
def __init__(self, file: T.Optional[str], **_):
|
def __init__(self, file: T.Optional[str], **_):
|
||||||
# for k in _.keys():
|
|
||||||
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
|
|
||||||
# (k == "c" and _[k] not in [2, 3]):
|
|
||||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
|
||||||
super().__init__(file=file, **_)
|
super().__init__(file=file, **_)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -307,14 +353,78 @@ class Image(BaseMessageComponent):
|
|||||||
def fromIO(IO):
|
def fromIO(IO):
|
||||||
return Image.fromBytes(IO.read())
|
return Image.fromBytes(IO.read())
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 图片的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
url = self.url if self.url else self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
image_file_path = url[8:]
|
||||||
|
return image_file_path
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(url)
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
elif url and url.startswith("base64://"):
|
||||||
|
bs64_data = url.removeprefix("base64://")
|
||||||
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
|
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||||
|
with open(image_file_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
elif os.path.exists(url):
|
||||||
|
image_file_path = url
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
|
||||||
|
async def convert_to_base64(self) -> str:
|
||||||
|
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
|
"""
|
||||||
|
# convert to base64
|
||||||
|
url = self.url if self.url else self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
bs64_data = file_to_base64(url[8:])
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(url)
|
||||||
|
bs64_data = file_to_base64(image_file_path)
|
||||||
|
elif url and url.startswith("base64://"):
|
||||||
|
bs64_data = url
|
||||||
|
elif os.path.exists(url):
|
||||||
|
bs64_data = file_to_base64(url)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
bs64_data = bs64_data.removeprefix("base64://")
|
||||||
|
return bs64_data
|
||||||
|
|
||||||
|
|
||||||
class Reply(BaseMessageComponent):
|
class Reply(BaseMessageComponent):
|
||||||
type: ComponentType = "Reply"
|
type: ComponentType = "Reply"
|
||||||
id: T.Union[str, int]
|
id: T.Union[str, int]
|
||||||
text: T.Optional[str] = ""
|
"""所引用的消息 ID"""
|
||||||
qq: T.Optional[int] = 0
|
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||||
|
"""引用的消息段列表"""
|
||||||
|
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||||
|
"""引用的消息发送者 ID"""
|
||||||
|
sender_nickname: T.Optional[str] = ""
|
||||||
|
"""引用的消息发送者昵称"""
|
||||||
time: T.Optional[int] = 0
|
time: T.Optional[int] = 0
|
||||||
|
"""引用的消息发送时间"""
|
||||||
|
message_str: T.Optional[str] = ""
|
||||||
|
"""解析后的纯文本消息字符串"""
|
||||||
|
sender_str: T.Optional[str] = ""
|
||||||
|
"""被引用的消息纯文本"""
|
||||||
|
|
||||||
|
text: T.Optional[str] = ""
|
||||||
|
"""deprecated"""
|
||||||
|
qq: T.Optional[int] = 0
|
||||||
|
"""deprecated"""
|
||||||
seq: T.Optional[int] = 0
|
seq: T.Optional[int] = 0
|
||||||
|
"""deprecated"""
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
@@ -353,16 +463,22 @@ class Node(BaseMessageComponent):
|
|||||||
id: T.Optional[int] = 0 # 忽略
|
id: T.Optional[int] = 0 # 忽略
|
||||||
name: T.Optional[str] = "" # qq昵称
|
name: T.Optional[str] = "" # qq昵称
|
||||||
uin: T.Optional[int] = 0 # qq号
|
uin: T.Optional[int] = 0 # qq号
|
||||||
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
|
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||||
time: T.Optional[int] = 0
|
time: T.Optional[int] = 0
|
||||||
|
|
||||||
def __init__(self, content: T.Union[str, list], **_):
|
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
_content = ""
|
_content = None
|
||||||
for chain in content:
|
if all(isinstance(item, Node) for item in content):
|
||||||
_content += chain.toString()
|
_content = [node.toDict() for node in content]
|
||||||
|
else:
|
||||||
|
_content = ""
|
||||||
|
for chain in content:
|
||||||
|
_content += chain.toString()
|
||||||
content = _content
|
content = _content
|
||||||
|
elif isinstance(content, Node):
|
||||||
|
content = content.toDict()
|
||||||
super().__init__(content=content, **_)
|
super().__init__(content=content, **_)
|
||||||
|
|
||||||
def toString(self):
|
def toString(self):
|
||||||
@@ -449,6 +565,16 @@ class File(BaseMessageComponent):
|
|||||||
super().__init__(name=name, file=file)
|
super().__init__(name=name, file=file)
|
||||||
|
|
||||||
|
|
||||||
|
class WechatEmoji(BaseMessageComponent):
|
||||||
|
type: ComponentType = "WechatEmoji"
|
||||||
|
md5: T.Optional[str] = ""
|
||||||
|
md5_len: T.Optional[int] = 0
|
||||||
|
cdnurl: T.Optional[str] = ""
|
||||||
|
|
||||||
|
def __init__(self, **_):
|
||||||
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
ComponentTypes = {
|
ComponentTypes = {
|
||||||
"plain": Plain,
|
"plain": Plain,
|
||||||
"text": Plain,
|
"text": Plain,
|
||||||
@@ -477,4 +603,5 @@ ComponentTypes = {
|
|||||||
"tts": TTS,
|
"tts": TTS,
|
||||||
"unknown": Unknown,
|
"unknown": Unknown,
|
||||||
"file": File,
|
"file": File,
|
||||||
|
"WechatEmoji": WechatEmoji,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
import enum
|
import enum
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union, AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
|
from astrbot.core.message.components import (
|
||||||
|
BaseMessageComponent,
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
At,
|
||||||
|
AtAll,
|
||||||
|
)
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +37,30 @@ class MessageChain:
|
|||||||
self.chain.append(Plain(message))
|
self.chain.append(Plain(message))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def at(self, name: str, qq: Union[str, int]):
|
||||||
|
"""添加一条 At 消息到消息链 `chain` 中。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().at("张三", "12345678910")
|
||||||
|
# 输出 @张三
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chain.append(At(name=name, qq=qq))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def at_all(self):
|
||||||
|
"""添加一条 AtAll 消息到消息链 `chain` 中。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().at_all()
|
||||||
|
# 输出 @所有人
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chain.append(AtAll())
|
||||||
|
return self
|
||||||
|
|
||||||
@deprecated("请使用 message 方法代替。")
|
@deprecated("请使用 message 方法代替。")
|
||||||
def error(self, message: str):
|
def error(self, message: str):
|
||||||
"""添加一条错误消息到消息链 `chain` 中
|
"""添加一条错误消息到消息链 `chain` 中
|
||||||
@@ -77,6 +107,34 @@ class MessageChain:
|
|||||||
self.use_t2i_ = use_t2i
|
self.use_t2i_ = use_t2i
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_plain_text(self) -> str:
|
||||||
|
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||||
|
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||||
|
|
||||||
|
def squash_plain(self):
|
||||||
|
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||||
|
if not self.chain:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_chain = []
|
||||||
|
first_plain = None
|
||||||
|
plain_texts = []
|
||||||
|
|
||||||
|
for comp in self.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
if first_plain is None:
|
||||||
|
first_plain = comp
|
||||||
|
new_chain.append(comp)
|
||||||
|
plain_texts.append(comp.text)
|
||||||
|
else:
|
||||||
|
new_chain.append(comp)
|
||||||
|
|
||||||
|
if first_plain is not None:
|
||||||
|
first_plain.text = "".join(plain_texts)
|
||||||
|
|
||||||
|
self.chain = new_chain
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class EventResultType(enum.Enum):
|
class EventResultType(enum.Enum):
|
||||||
"""用于描述事件处理的结果类型。
|
"""用于描述事件处理的结果类型。
|
||||||
@@ -97,6 +155,10 @@ class ResultContentType(enum.Enum):
|
|||||||
"""调用 LLM 产生的结果"""
|
"""调用 LLM 产生的结果"""
|
||||||
GENERAL_RESULT = enum.auto()
|
GENERAL_RESULT = enum.auto()
|
||||||
"""普通的消息结果"""
|
"""普通的消息结果"""
|
||||||
|
STREAMING_RESULT = enum.auto()
|
||||||
|
"""调用 LLM 产生的流式结果"""
|
||||||
|
STREAMING_FINISH= enum.auto()
|
||||||
|
"""流式输出完成"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -118,6 +180,9 @@ class MessageEventResult(MessageChain):
|
|||||||
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async_stream: Optional[AsyncGenerator] = None
|
||||||
|
"""异步流"""
|
||||||
|
|
||||||
def stop_event(self) -> "MessageEventResult":
|
def stop_event(self) -> "MessageEventResult":
|
||||||
"""终止事件传播。"""
|
"""终止事件传播。"""
|
||||||
self.result_type = EventResultType.STOP
|
self.result_type = EventResultType.STOP
|
||||||
@@ -134,6 +199,11 @@ class MessageEventResult(MessageChain):
|
|||||||
"""
|
"""
|
||||||
return self.result_type == EventResultType.STOP
|
return self.result_type == EventResultType.STOP
|
||||||
|
|
||||||
|
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
|
||||||
|
"""设置异步流。"""
|
||||||
|
self.async_stream = stream
|
||||||
|
return self
|
||||||
|
|
||||||
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
||||||
"""设置事件处理的结果类型。
|
"""设置事件处理的结果类型。
|
||||||
|
|
||||||
@@ -147,9 +217,6 @@ class MessageEventResult(MessageChain):
|
|||||||
"""是否为 LLM 结果。"""
|
"""是否为 LLM 结果。"""
|
||||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||||
|
|
||||||
def get_plain_text(self) -> str:
|
|
||||||
"""获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
|
||||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
|
||||||
|
|
||||||
|
|
||||||
|
# 为了兼容旧版代码,保留 CommandResult 的别名
|
||||||
CommandResult = MessageEventResult
|
CommandResult = MessageEventResult
|
||||||
|
|||||||
@@ -7,16 +7,19 @@ from .waking_check.stage import WakingCheckStage
|
|||||||
from .whitelist_check.stage import WhitelistCheckStage
|
from .whitelist_check.stage import WhitelistCheckStage
|
||||||
from .rate_limit_check.stage import RateLimitStage
|
from .rate_limit_check.stage import RateLimitStage
|
||||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||||
|
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||||
from .preprocess_stage.stage import PreProcessStage
|
from .preprocess_stage.stage import PreProcessStage
|
||||||
from .process_stage.stage import ProcessStage
|
from .process_stage.stage import ProcessStage
|
||||||
from .result_decorate.stage import ResultDecorateStage
|
from .result_decorate.stage import ResultDecorateStage
|
||||||
from .respond.stage import RespondStage
|
from .respond.stage import RespondStage
|
||||||
|
|
||||||
|
# 管道阶段顺序
|
||||||
STAGES_ORDER = [
|
STAGES_ORDER = [
|
||||||
"WakingCheckStage", # 检查是否需要唤醒
|
"WakingCheckStage", # 检查是否需要唤醒
|
||||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||||
"RateLimitStage", # 检查会话是否超过频率限制
|
"RateLimitStage", # 检查会话是否超过频率限制
|
||||||
"ContentSafetyCheckStage", # 检查内容安全
|
"ContentSafetyCheckStage", # 检查内容安全
|
||||||
|
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||||
"PreProcessStage", # 预处理
|
"PreProcessStage", # 预处理
|
||||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||||
@@ -28,6 +31,7 @@ __all__ = [
|
|||||||
"WhitelistCheckStage",
|
"WhitelistCheckStage",
|
||||||
"RateLimitStage",
|
"RateLimitStage",
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
|
"PlatformCompatibilityStage",
|
||||||
"PreProcessStage",
|
"PreProcessStage",
|
||||||
"ProcessStage",
|
"ProcessStage",
|
||||||
"ResultDecorateStage",
|
"ResultDecorateStage",
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
import re
|
import re
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import base64
|
|
||||||
from . import ContentSafetyStrategy
|
from . import ContentSafetyStrategy
|
||||||
|
|
||||||
|
|
||||||
@@ -11,13 +8,13 @@ class KeywordsStrategy(ContentSafetyStrategy):
|
|||||||
if extra_keywords is None:
|
if extra_keywords is None:
|
||||||
extra_keywords = []
|
extra_keywords = []
|
||||||
self.keywords.extend(extra_keywords)
|
self.keywords.extend(extra_keywords)
|
||||||
keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
|
# keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
|
||||||
# internal keywords
|
# internal keywords
|
||||||
if os.path.exists(keywords_path):
|
# if os.path.exists(keywords_path):
|
||||||
with open(keywords_path, "r", encoding="utf-8") as f:
|
# with open(keywords_path, "r", encoding="utf-8") as f:
|
||||||
self.keywords.extend(
|
# self.keywords.extend(
|
||||||
json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||||
)
|
# )
|
||||||
|
|
||||||
def check(self, content: str) -> bool:
|
def check(self, content: str) -> bool:
|
||||||
for keyword in self.keywords:
|
for keyword in self.keywords:
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
ewogICAgImtleXdvcmRzIjogWwogICAgICAgICLkuaDov5HlubMiLAogICAgICAgICLog6HplKbmtpsiLAogICAgICAgICLmsZ/ms73msJEiLAogICAgICAgICLmuKnlrrblrp0iLAogICAgICAgICLmnY7lhYvlvLoiLAogICAgICAgICLmnY7plb/mmKUiLAogICAgICAgICLmr5vms73kuJwiLAogICAgICAgICLpgpPlsI/lubMiLAogICAgICAgICLlkajmganmnaUiLAogICAgICAgICLnpL7kvJrkuLvkuYkiLAogICAgICAgICLlhbHkuqflhZoiLAogICAgICAgICLlhbHkuqfkuLvkuYkiLAogICAgICAgICLlpKfpmYblrpjmlrkiLAogICAgICAgICLljJfkuqzmlL/mnYMiLAogICAgICAgICLkuK3ljY7luJ3lm70iLAogICAgICAgICLkuK3lm73mlL/lupwiLAogICAgICAgICLlhbHni5ciLAogICAgICAgICLlha3lm5vkuovku7YiLAogICAgICAgICLlpKnlronpl6giLAogICAgICAgICLlha3lm5siLAogICAgICAgICLmlL/msrvlsYDluLjlp5QiLAogICAgICAgICLlrabmva4iLAogICAgICAgICLlhavkuZ0iLAogICAgICAgICLkuozljYHlpKciLAogICAgICAgICLmsJHov5vlhZoiLAogICAgICAgICLlj7Dni6wiLAogICAgICAgICLlj7Dmub7ni6znq4siLAogICAgICAgICLlj7Dmub7lm70iLAogICAgICAgICLlm73msJHlhZoiLAogICAgICAgICLlj7Dmub7msJHlm70iLAogICAgICAgICLkuK3ljY7msJHlm70iLAogICAgICAgICJwb3JuaHViIiwKICAgICAgICAiUG9ybmh1YiIsCiAgICAgICAgIuS9nOeIsSIsCiAgICAgICAgIuWBmueIsSIsCiAgICAgICAgIuaAp+S6pCIsCiAgICAgICAgIuiHquaFsCIsCiAgICAgICAgIumYtOiMjiIsCiAgICAgICAgIua3q+WmhyIsCiAgICAgICAgIuiCm+S6pCIsCiAgICAgICAgIuS6pOmFjSIsCiAgICAgICAgIuaAp+WFs+ezuyIsCiAgICAgICAgIuaAp+a0u+WKqCIsCiAgICAgICAgIuiJsuaDhSIsCiAgICAgICAgIuiJsuWbviIsCiAgICAgICAgIuijuOS9kyIsCiAgICAgICAgIuWwj+eptCIsCiAgICAgICAgIua3q+iNoSIsCiAgICAgICAgIuaAp+eIsSIsCiAgICAgICAgIua4r+eLrCIsCiAgICAgICAgIuazlei9ruWKnyIsCiAgICAgICAgIuWFreWbmyIKICAgIF0KfQ==
|
|
||||||
@@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineContext:
|
class PipelineContext:
|
||||||
astrbot_config: AstrBotConfig
|
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||||
plugin_manager: PluginManager
|
|
||||||
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||||
|
plugin_manager: PluginManager # 插件管理器对象
|
||||||
|
|||||||
56
astrbot/core/pipeline/platform_compatibility/stage.py
Normal file
56
astrbot/core/pipeline/platform_compatibility/stage.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from ..stage import Stage, register_stage
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from typing import Union, AsyncGenerator
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
@register_stage
|
||||||
|
class PlatformCompatibilityStage(Stage):
|
||||||
|
"""检查所有处理器的平台兼容性。
|
||||||
|
|
||||||
|
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
"""初始化平台兼容性检查阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
|
self.ctx = ctx
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
# 获取当前平台ID
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
|
||||||
|
# 获取已激活的处理器
|
||||||
|
activated_handlers = event.get_extra("activated_handlers")
|
||||||
|
if activated_handlers is None:
|
||||||
|
activated_handlers = []
|
||||||
|
|
||||||
|
# 标记不兼容的处理器
|
||||||
|
for handler in activated_handlers:
|
||||||
|
if not isinstance(handler, StarHandlerMetadata):
|
||||||
|
continue
|
||||||
|
# 检查处理器是否在当前平台启用
|
||||||
|
enabled = handler.is_enabled_for_platform(platform_id)
|
||||||
|
if not enabled:
|
||||||
|
if handler.handler_module_path in star_map:
|
||||||
|
plugin_name = star_map[handler.handler_module_path].name
|
||||||
|
logger.debug(
|
||||||
|
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||||
|
)
|
||||||
|
# 设置处理器为平台不兼容状态
|
||||||
|
# TODO: 更好的标记方式
|
||||||
|
handler.platform_compatible = False
|
||||||
|
else:
|
||||||
|
# 确保处理器为平台兼容状态
|
||||||
|
handler.platform_compatible = True
|
||||||
|
|
||||||
|
# 更新已激活的处理器列表
|
||||||
|
event.set_extra("activated_handlers", activated_handlers)
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ...context import PipelineContext
|
from ...context import PipelineContext
|
||||||
@@ -11,11 +12,18 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
ResultContentType,
|
ResultContentType,
|
||||||
|
MessageChain,
|
||||||
)
|
)
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
|
from astrbot.core.provider.entities import (
|
||||||
|
ProviderRequest,
|
||||||
|
LLMResponse,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
AssistantMessageSegment,
|
||||||
|
ToolCallsResult,
|
||||||
|
)
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
|
|
||||||
@@ -27,6 +35,16 @@ class LLMRequestSubStage(Stage):
|
|||||||
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
||||||
"wake_prefix"
|
"wake_prefix"
|
||||||
] # str
|
] # str
|
||||||
|
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||||
|
"max_context_length"
|
||||||
|
] # int
|
||||||
|
self.dequeue_context_length = min(
|
||||||
|
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||||
|
self.max_context_length - 1,
|
||||||
|
) # int
|
||||||
|
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||||
|
"streaming_response"
|
||||||
|
] # bool
|
||||||
|
|
||||||
for bwp in self.bot_wake_prefixs:
|
for bwp in self.bot_wake_prefixs:
|
||||||
if self.provider_wake_prefix.startswith(bwp):
|
if self.provider_wake_prefix.startswith(bwp):
|
||||||
@@ -48,12 +66,16 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
if event.get_extra("provider_request"):
|
if event.get_extra("provider_request"):
|
||||||
req = event.get_extra("provider_request")
|
req = event.get_extra("provider_request")
|
||||||
assert isinstance(req, ProviderRequest), (
|
assert isinstance(
|
||||||
"provider_request 必须是 ProviderRequest 类型。"
|
req, ProviderRequest
|
||||||
)
|
), "provider_request 必须是 ProviderRequest 类型。"
|
||||||
|
|
||||||
if req.conversation:
|
if req.conversation:
|
||||||
req.contexts = json.loads(req.conversation.history)
|
all_contexts = json.loads(req.conversation.history)
|
||||||
|
req.contexts = self._process_tool_message_pairs(
|
||||||
|
all_contexts, remove_tags=True
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
req = ProviderRequest(prompt="", image_urls=[])
|
req = ProviderRequest(prompt="", image_urls=[])
|
||||||
if self.provider_wake_prefix:
|
if self.provider_wake_prefix:
|
||||||
@@ -63,8 +85,8 @@ class LLMRequestSubStage(Stage):
|
|||||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||||
for comp in event.message_obj.message:
|
for comp in event.message_obj.message:
|
||||||
if isinstance(comp, Image):
|
if isinstance(comp, Image):
|
||||||
image_url = comp.url if comp.url else comp.file
|
image_path = await comp.convert_to_file_path()
|
||||||
req.image_urls.append(image_url)
|
req.image_urls.append(image_path)
|
||||||
|
|
||||||
# 获取对话上下文
|
# 获取对话上下文
|
||||||
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
||||||
@@ -74,10 +96,16 @@ class LLMRequestSubStage(Stage):
|
|||||||
conversation_id = await self.conv_manager.new_conversation(
|
conversation_id = await self.conv_manager.new_conversation(
|
||||||
event.unified_msg_origin
|
event.unified_msg_origin
|
||||||
)
|
)
|
||||||
req.session_id = event.unified_msg_origin
|
|
||||||
conversation = await self.conv_manager.get_conversation(
|
conversation = await self.conv_manager.get_conversation(
|
||||||
event.unified_msg_origin, conversation_id
|
event.unified_msg_origin, conversation_id
|
||||||
)
|
)
|
||||||
|
if not conversation:
|
||||||
|
conversation_id = await self.conv_manager.new_conversation(
|
||||||
|
event.unified_msg_origin
|
||||||
|
)
|
||||||
|
conversation = await self.conv_manager.get_conversation(
|
||||||
|
event.unified_msg_origin, conversation_id
|
||||||
|
)
|
||||||
req.conversation = conversation
|
req.conversation = conversation
|
||||||
req.contexts = json.loads(conversation.history)
|
req.contexts = json.loads(conversation.history)
|
||||||
|
|
||||||
@@ -88,8 +116,10 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
# 执行请求 LLM 前事件钩子。
|
# 执行请求 LLM 前事件钩子。
|
||||||
# 装饰 system_prompt 等功能
|
# 装饰 system_prompt 等功能
|
||||||
|
# 获取当前平台ID
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
EventType.OnLLMRequestEvent
|
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
@@ -109,108 +139,313 @@ class LLMRequestSubStage(Stage):
|
|||||||
if isinstance(req.contexts, str):
|
if isinstance(req.contexts, str):
|
||||||
req.contexts = json.loads(req.contexts)
|
req.contexts = json.loads(req.contexts)
|
||||||
|
|
||||||
try:
|
# max context length
|
||||||
logger.debug(f"提供商请求 Payload: {req}")
|
if (
|
||||||
if _nested:
|
self.max_context_length != -1 # -1 为不限制
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
and len(req.contexts) // 2 > self.max_context_length
|
||||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
):
|
||||||
|
logger.debug("上下文长度超过限制,将截断。")
|
||||||
|
req.contexts = req.contexts[
|
||||||
|
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||||
|
]
|
||||||
|
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||||
|
index = next((i for i, item in enumerate(req.contexts) if item.get("role") == "user"), None)
|
||||||
|
if index is not None and index > 0:
|
||||||
|
req.contexts = req.contexts[index:]
|
||||||
|
|
||||||
# 执行 LLM 响应后的事件钩子。
|
# session_id
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
if not req.session_id:
|
||||||
EventType.OnLLMResponseEvent
|
req.session_id = event.unified_msg_origin
|
||||||
)
|
|
||||||
for handler in handlers:
|
async def requesting(req: ProviderRequest):
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
need_loop = True
|
||||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
while need_loop:
|
||||||
|
need_loop = False
|
||||||
|
logger.debug(f"提供商请求 Payload: {req}")
|
||||||
|
|
||||||
|
final_llm_response = None
|
||||||
|
|
||||||
|
if self.streaming_response:
|
||||||
|
stream = provider.text_chat_stream(**req.__dict__)
|
||||||
|
async for llm_response in stream:
|
||||||
|
if llm_response.is_chunk:
|
||||||
|
if llm_response.result_chain:
|
||||||
|
yield llm_response.result_chain # MessageChain
|
||||||
|
else:
|
||||||
|
yield MessageChain().message(
|
||||||
|
llm_response.completion_text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_llm_response = llm_response
|
||||||
|
else:
|
||||||
|
final_llm_response = await provider.text_chat(
|
||||||
|
**req.__dict__
|
||||||
|
) # 请求 LLM
|
||||||
|
|
||||||
|
if not final_llm_response:
|
||||||
|
raise Exception("LLM response is None.")
|
||||||
|
|
||||||
|
# 执行 LLM 响应后的事件钩子。
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnLLMResponseEvent
|
||||||
)
|
)
|
||||||
await handler.handler(event, llm_response)
|
for handler in handlers:
|
||||||
except BaseException:
|
try:
|
||||||
logger.error(traceback.format_exc())
|
logger.debug(
|
||||||
|
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler(event, final_llm_response)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.streaming_response:
|
||||||
|
# 流式输出的处理
|
||||||
|
async for result in self._handle_llm_stream_response(
|
||||||
|
event, req, final_llm_response
|
||||||
|
):
|
||||||
|
if isinstance(result, ProviderRequest):
|
||||||
|
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||||
|
req = result
|
||||||
|
need_loop = True
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
# 非流式输出的处理
|
||||||
|
async for result in self._handle_llm_response(
|
||||||
|
event, req, final_llm_response
|
||||||
|
):
|
||||||
|
if isinstance(result, ProviderRequest):
|
||||||
|
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||||
|
req = result
|
||||||
|
need_loop = True
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
llm_tick=1,
|
||||||
|
model_name=provider.get_model(),
|
||||||
|
provider_type=provider.meta().type,
|
||||||
)
|
)
|
||||||
return
|
)
|
||||||
|
|
||||||
# 保存到历史记录
|
# 保存到历史记录
|
||||||
await self._save_to_history(event, req, llm_response)
|
await self._save_to_history(event, req, final_llm_response)
|
||||||
|
|
||||||
await Metric.upload(
|
except BaseException as e:
|
||||||
llm_tick=1,
|
logger.error(traceback.format_exc())
|
||||||
model_name=provider.get_model(),
|
event.set_result(
|
||||||
provider_type=provider.meta().type,
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.streaming_response:
|
||||||
|
event.set_extra("tool_call_result", None)
|
||||||
|
async for _ in requesting(req):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||||
|
.set_async_stream(requesting(req))
|
||||||
)
|
)
|
||||||
|
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||||
|
yield
|
||||||
|
|
||||||
if llm_response.role == "assistant":
|
if event.get_extra("tool_call_result"):
|
||||||
# text completion
|
event.set_result(event.get_extra("tool_call_result"))
|
||||||
|
event.set_extra("tool_call_result", None)
|
||||||
|
yield
|
||||||
|
|
||||||
|
async def _handle_llm_response(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理非流式 LLM 响应。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# text completion
|
||||||
|
if llm_response.result_chain:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult(
|
||||||
|
chain=llm_response.result_chain.chain
|
||||||
|
).set_result_content_type(ResultContentType.LLM_RESULT)
|
||||||
|
)
|
||||||
|
else:
|
||||||
event.set_result(
|
event.set_result(
|
||||||
MessageEventResult()
|
MessageEventResult()
|
||||||
.message(llm_response.completion_text)
|
.message(llm_response.completion_text)
|
||||||
.set_result_content_type(ResultContentType.LLM_RESULT)
|
.set_result_content_type(ResultContentType.LLM_RESULT)
|
||||||
)
|
)
|
||||||
elif llm_response.role == "err":
|
elif llm_response.role == "err":
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "tool":
|
||||||
|
# 处理函数工具调用
|
||||||
|
async for result in self._handle_function_tools(event, req, llm_response):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
async def _handle_llm_stream_response(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理流式 LLM 响应。
|
||||||
|
|
||||||
|
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# text completion
|
||||||
|
if llm_response.result_chain:
|
||||||
event.set_result(
|
event.set_result(
|
||||||
MessageEventResult().message(
|
MessageEventResult(
|
||||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
chain=llm_response.result_chain.chain
|
||||||
|
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.message(llm_response.completion_text)
|
||||||
|
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "err":
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "tool":
|
||||||
|
# 处理函数工具调用
|
||||||
|
async for result in self._handle_function_tools(event, req, llm_response):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
async def _handle_function_tools(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理函数工具调用。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
# function calling
|
||||||
|
tool_call_result: list[ToolCallMessageSegment] = []
|
||||||
|
logger.info(
|
||||||
|
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||||
|
)
|
||||||
|
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||||
|
llm_response.tools_call_name,
|
||||||
|
llm_response.tools_call_args,
|
||||||
|
llm_response.tools_call_ids,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
func_tool = req.func_tool.get_func(func_tool_name)
|
||||||
|
if func_tool.origin == "mcp":
|
||||||
|
logger.info(
|
||||||
|
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||||
)
|
)
|
||||||
)
|
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||||
elif llm_response.role == "tool":
|
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||||
# function calling
|
if res:
|
||||||
function_calling_result = {}
|
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||||
logger.info(
|
tool_call_result.append(
|
||||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
ToolCallMessageSegment(
|
||||||
)
|
role="tool",
|
||||||
for func_tool_name, func_tool_args in zip(
|
tool_call_id=func_tool_id,
|
||||||
llm_response.tools_call_name, llm_response.tools_call_args
|
content=res.content[0].text,
|
||||||
):
|
)
|
||||||
func_tool = req.func_tool.get_func(func_tool_name)
|
)
|
||||||
|
else:
|
||||||
|
# 获取处理器,过滤掉平台不兼容的处理器
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
star_md = star_map.get(func_tool.handler_module_path)
|
||||||
|
if (
|
||||||
|
star_md and
|
||||||
|
platform_id in star_md.supported_platforms
|
||||||
|
and not star_md.supported_platforms[platform_id]
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||||
|
)
|
||||||
|
# 直接跳过,不添加任何消息到tool_call_result
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||||
)
|
)
|
||||||
try:
|
# 尝试调用工具函数
|
||||||
# 尝试调用工具函数
|
wrapper = self._call_handler(
|
||||||
wrapper = self._call_handler(
|
self.ctx, event, func_tool.handler, **func_tool_args
|
||||||
self.ctx, event, func_tool.handler, **func_tool_args
|
)
|
||||||
)
|
async for resp in wrapper:
|
||||||
async for resp in wrapper:
|
if resp is not None: # 有 return 返回
|
||||||
if resp is not None: # 有 return 返回
|
tool_call_result.append(
|
||||||
function_calling_result[func_tool_name] = resp
|
ToolCallMessageSegment(
|
||||||
else:
|
role="tool",
|
||||||
yield # 有生成器返回
|
tool_call_id=func_tool_id,
|
||||||
event.clear_result() # 清除上一个 handler 的结果
|
content=resp,
|
||||||
except BaseException as e:
|
)
|
||||||
logger.warning(traceback.format_exc())
|
)
|
||||||
function_calling_result[func_tool_name] = (
|
else:
|
||||||
"When calling the function, an error occurred: " + str(e)
|
res = event.get_result()
|
||||||
)
|
if res and res.chain:
|
||||||
if function_calling_result:
|
event.set_extra("tool_call_result", res)
|
||||||
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
|
yield # 有生成器返回
|
||||||
# 我们重新执行一遍这个 stage
|
event.clear_result() # 清除上一个 handler 的结果
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
except BaseException as e:
|
||||||
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
|
logger.warning(traceback.format_exc())
|
||||||
for tool_name, tool_result in function_calling_result.items():
|
tool_call_result.append(
|
||||||
extra_prompt += (
|
ToolCallMessageSegment(
|
||||||
f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
role="tool",
|
||||||
)
|
tool_call_id=func_tool_id,
|
||||||
req.prompt += extra_prompt
|
content=f"error: {str(e)}",
|
||||||
async for _ in self.process(event, _nested=True):
|
)
|
||||||
yield
|
|
||||||
else:
|
|
||||||
if llm_response.completion_text:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(llm_response.completion_text)
|
|
||||||
)
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(
|
|
||||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
if tool_call_result:
|
||||||
|
# 函数调用结果
|
||||||
|
req.func_tool = None # 暂时不支持递归工具调用
|
||||||
|
assistant_msg_seg = AssistantMessageSegment(
|
||||||
|
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||||
)
|
)
|
||||||
return
|
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||||
|
req.tool_calls_result = ToolCallsResult(
|
||||||
|
tool_calls_info=assistant_msg_seg,
|
||||||
|
tool_calls_result=tool_call_result,
|
||||||
|
)
|
||||||
|
yield req # 再次执行 LLM 请求
|
||||||
|
else:
|
||||||
|
if llm_response.completion_text:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(llm_response.completion_text)
|
||||||
|
)
|
||||||
|
|
||||||
async def _save_to_history(
|
async def _save_to_history(
|
||||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||||
@@ -220,9 +455,23 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
if llm_response.role == "assistant":
|
if llm_response.role == "assistant":
|
||||||
# 文本回复
|
# 文本回复
|
||||||
contexts = req.contexts
|
contexts = req.contexts.copy()
|
||||||
new_record = {"role": "user", "content": req.prompt}
|
contexts.append(await req.assemble_context())
|
||||||
contexts.append(new_record)
|
|
||||||
|
# 记录并标记函数调用结果
|
||||||
|
if req.tool_calls_result:
|
||||||
|
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||||
|
|
||||||
|
# 添加标记
|
||||||
|
for message in tool_calls_messages:
|
||||||
|
message["_tool_call_history"] = True
|
||||||
|
|
||||||
|
processed_tool_messages = self._process_tool_message_pairs(
|
||||||
|
tool_calls_messages, remove_tags=False
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.extend(processed_tool_messages)
|
||||||
|
|
||||||
contexts.append(
|
contexts.append(
|
||||||
{"role": "assistant", "content": llm_response.completion_text}
|
{"role": "assistant", "content": llm_response.completion_text}
|
||||||
)
|
)
|
||||||
@@ -232,3 +481,59 @@ class LLMRequestSubStage(Stage):
|
|||||||
await self.conv_manager.update_conversation(
|
await self.conv_manager.update_conversation(
|
||||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||||
|
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list): 消息列表
|
||||||
|
remove_tags (bool): 是否移除_tool_call_history标记
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
while i < len(messages):
|
||||||
|
current_msg = messages[i]
|
||||||
|
|
||||||
|
# 普通消息直接添加
|
||||||
|
if "_tool_call_history" not in current_msg:
|
||||||
|
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 工具调用消息成对处理
|
||||||
|
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||||
|
assistant_msg = current_msg.copy()
|
||||||
|
|
||||||
|
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||||
|
del assistant_msg["_tool_call_history"]
|
||||||
|
|
||||||
|
related_tools = []
|
||||||
|
j = i + 1
|
||||||
|
while (
|
||||||
|
j < len(messages)
|
||||||
|
and messages[j].get("role") == "tool"
|
||||||
|
and "_tool_call_history" in messages[j]
|
||||||
|
):
|
||||||
|
tool_msg = messages[j].copy()
|
||||||
|
|
||||||
|
if remove_tags:
|
||||||
|
del tool_msg["_tool_call_history"]
|
||||||
|
|
||||||
|
related_tools.append(tool_msg)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# 成对的时候添加到结果
|
||||||
|
if related_tools:
|
||||||
|
result.append(assistant_msg)
|
||||||
|
result.extend(related_tools)
|
||||||
|
|
||||||
|
i = j # 跳过已处理
|
||||||
|
else:
|
||||||
|
# 单独的tool消息
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
|
|||||||
)
|
)
|
||||||
if not handlers_parsed_params:
|
if not handlers_parsed_params:
|
||||||
handlers_parsed_params = {}
|
handlers_parsed_params = {}
|
||||||
|
|
||||||
for handler in activated_handlers:
|
for handler in activated_handlers:
|
||||||
|
# 检查处理器是否在当前平台兼容
|
||||||
|
if (
|
||||||
|
hasattr(handler, "platform_compatible")
|
||||||
|
and handler.platform_compatible is False
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||||
try:
|
try:
|
||||||
if handler.handler_module_path not in star_map:
|
if handler.handler_module_path not in star_map:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage
|
|||||||
from .method.star_request import StarRequestSubStage
|
from .method.star_request import StarRequestSubStage
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from astrbot.core.provider.entities import ProviderRequest
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,22 +2,63 @@ import random
|
|||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import traceback
|
import traceback
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage, Stage
|
from ..stage import register_stage, Stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.message.components import Plain, Reply, At
|
from astrbot.core.utils.path_util import path_Mapping
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class RespondStage(Stage):
|
class RespondStage(Stage):
|
||||||
|
# 组件类型到其非空判断函数的映射
|
||||||
|
_component_validators = {
|
||||||
|
Comp.Plain: lambda comp: bool(
|
||||||
|
comp.text and comp.text.strip()
|
||||||
|
), # 纯文本消息需要strip
|
||||||
|
Comp.Face: lambda comp: comp.id is not None, # QQ表情
|
||||||
|
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||||
|
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||||
|
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||||
|
Comp.AtAll: lambda comp: True, # @所有人
|
||||||
|
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||||
|
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||||
|
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||||
|
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||||
|
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||||
|
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||||
|
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||||
|
Comp.Music: lambda comp: bool(comp._type)
|
||||||
|
and bool(comp.url)
|
||||||
|
and bool(comp.audio), # 音乐
|
||||||
|
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||||
|
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||||
|
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||||
|
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||||
|
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||||
|
Comp.Node: lambda comp: bool(comp.name)
|
||||||
|
and comp.uin != 0
|
||||||
|
and bool(comp.content), # 一个转发节点
|
||||||
|
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||||
|
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||||
|
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||||
|
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||||
|
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||||
|
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||||
|
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||||
|
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||||
|
}
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext):
|
async def initialize(self, ctx: PipelineContext):
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
|
self.config = ctx.astrbot_config
|
||||||
|
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||||
|
|
||||||
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||||
"reply_with_mention"
|
"reply_with_mention"
|
||||||
@@ -62,7 +103,7 @@ class RespondStage(Stage):
|
|||||||
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
||||||
"""分段回复 计算间隔时间"""
|
"""分段回复 计算间隔时间"""
|
||||||
if self.interval_method == "log":
|
if self.interval_method == "log":
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Comp.Plain):
|
||||||
wc = await self._word_cnt(comp.text)
|
wc = await self._word_cnt(comp.text)
|
||||||
i = math.log(wc + 1, self.log_base)
|
i = math.log(wc + 1, self.log_base)
|
||||||
return random.uniform(i, i + 0.5)
|
return random.uniform(i, i + 0.5)
|
||||||
@@ -72,15 +113,67 @@ class RespondStage(Stage):
|
|||||||
# random
|
# random
|
||||||
return random.uniform(self.interval[0], self.interval[1])
|
return random.uniform(self.interval[0], self.interval[1])
|
||||||
|
|
||||||
|
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
|
||||||
|
"""检查消息链是否为空
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chain (list[BaseMessageComponent]): 包含消息对象的列表
|
||||||
|
"""
|
||||||
|
if not chain:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for comp in chain:
|
||||||
|
comp_type = type(comp)
|
||||||
|
|
||||||
|
# 检查组件类型是否在字典中
|
||||||
|
if comp_type in self._component_validators:
|
||||||
|
if self._component_validators[comp_type](comp):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||||
|
|
||||||
|
# 如果所有组件都为空
|
||||||
|
return True
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
|
return
|
||||||
|
|
||||||
if len(result.chain) > 0:
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
# 流式结果直接交付平台适配器处理
|
||||||
|
use_fallback = self.config.get("provider_settings", {}).get(
|
||||||
|
"streaming_segmented", False
|
||||||
|
)
|
||||||
|
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||||
await event._pre_send()
|
await event._pre_send()
|
||||||
|
await event.send_streaming(result.async_stream, use_fallback)
|
||||||
|
await event._post_send()
|
||||||
|
return
|
||||||
|
elif len(result.chain) > 0:
|
||||||
|
# 检查路径映射
|
||||||
|
if mappings := self.platform_settings.get("path_mapping", []):
|
||||||
|
for idx, component in enumerate(result.chain):
|
||||||
|
if isinstance(component, Comp.File) and component.file:
|
||||||
|
# 支持 File 消息段的路径映射。
|
||||||
|
component.file = path_Mapping(mappings, component.file)
|
||||||
|
event.get_result().chain[idx] = component
|
||||||
|
|
||||||
|
await event._pre_send()
|
||||||
|
|
||||||
|
# 检查消息链是否为空
|
||||||
|
try:
|
||||||
|
if await self._is_empty_message_chain(result.chain):
|
||||||
|
logger.info("消息为空,跳过发送阶段")
|
||||||
|
event.clear_result()
|
||||||
|
event.stop_event()
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"空内容检查异常: {e}")
|
||||||
|
|
||||||
if self.enable_seg and (
|
if self.enable_seg and (
|
||||||
(self.only_llm_result and result.is_llm_result())
|
(self.only_llm_result and result.is_llm_result())
|
||||||
@@ -89,13 +182,13 @@ class RespondStage(Stage):
|
|||||||
decorated_comps = []
|
decorated_comps = []
|
||||||
if self.reply_with_mention:
|
if self.reply_with_mention:
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, At):
|
if isinstance(comp, Comp.At):
|
||||||
decorated_comps.append(comp)
|
decorated_comps.append(comp)
|
||||||
result.chain.remove(comp)
|
result.chain.remove(comp)
|
||||||
break
|
break
|
||||||
if self.reply_with_quote:
|
if self.reply_with_quote:
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Reply):
|
if isinstance(comp, Comp.Reply):
|
||||||
decorated_comps.append(comp)
|
decorated_comps.append(comp)
|
||||||
result.chain.remove(comp)
|
result.chain.remove(comp)
|
||||||
break
|
break
|
||||||
@@ -103,16 +196,24 @@ class RespondStage(Stage):
|
|||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
i = await self._calc_comp_interval(comp)
|
i = await self._calc_comp_interval(comp)
|
||||||
await asyncio.sleep(i)
|
await asyncio.sleep(i)
|
||||||
await event.send(MessageChain([*decorated_comps, comp]))
|
try:
|
||||||
|
await event.send(MessageChain([*decorated_comps, comp]))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
await event.send(result)
|
try:
|
||||||
|
await event.send(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
await event._post_send()
|
await event._post_send()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
EventType.OnAfterMessageSentEvent
|
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator
|
|||||||
from ..stage import Stage, register_stage, registered_stages
|
from ..stage import Stage, register_stage, registered_stages
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.message.message_event_result import ResultContentType
|
||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||||
@@ -31,6 +32,8 @@ class ResultDecorateStage(Stage):
|
|||||||
self.t2i_word_threshold = 50
|
self.t2i_word_threshold = 50
|
||||||
except BaseException:
|
except BaseException:
|
||||||
self.t2i_word_threshold = 150
|
self.t2i_word_threshold = 150
|
||||||
|
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
||||||
|
self.t2i_use_network = self.t2i_strategy == "remote"
|
||||||
|
|
||||||
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
||||||
"forward_threshold"
|
"forward_threshold"
|
||||||
@@ -70,11 +73,17 @@ class ResultDecorateStage(Stage):
|
|||||||
if result is None or not result.chain:
|
if result is None or not result.chain:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
|
||||||
|
|
||||||
# 回复时检查内容安全
|
# 回复时检查内容安全
|
||||||
if (
|
if (
|
||||||
self.content_safe_check_reply
|
self.content_safe_check_reply
|
||||||
and self.content_safe_check_stage
|
and self.content_safe_check_stage
|
||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
|
and not is_stream # 流式输出不检查内容安全
|
||||||
):
|
):
|
||||||
text = ""
|
text = ""
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
@@ -87,13 +96,17 @@ class ResultDecorateStage(Stage):
|
|||||||
|
|
||||||
# 发送消息前事件钩子
|
# 发送消息前事件钩子
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
EventType.OnDecoratingResultEvent
|
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
)
|
)
|
||||||
|
if is_stream:
|
||||||
|
logger.warning(
|
||||||
|
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
|
||||||
|
)
|
||||||
await handler.handler(event)
|
await handler.handler(event)
|
||||||
if event.get_result() is None or not event.get_result().chain:
|
if event.get_result() is None or not event.get_result().chain:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -108,6 +121,11 @@ class ResultDecorateStage(Stage):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 流式输出不执行下面的逻辑
|
||||||
|
if is_stream:
|
||||||
|
logger.info("流式输出已启用,跳过结果装饰阶段")
|
||||||
|
return
|
||||||
|
|
||||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
@@ -133,9 +151,9 @@ class ResultDecorateStage(Stage):
|
|||||||
# 不分段回复
|
# 不分段回复
|
||||||
new_chain.append(comp)
|
new_chain.append(comp)
|
||||||
continue
|
continue
|
||||||
split_response = []
|
split_response = re.findall(
|
||||||
for line in comp.text.split("\n"):
|
self.regex, comp.text, re.DOTALL | re.MULTILINE
|
||||||
split_response.extend(re.findall(self.regex, line))
|
)
|
||||||
if not split_response:
|
if not split_response:
|
||||||
new_chain.append(comp)
|
new_chain.append(comp)
|
||||||
continue
|
continue
|
||||||
@@ -166,6 +184,8 @@ class ResultDecorateStage(Stage):
|
|||||||
new_chain.append(
|
new_chain.append(
|
||||||
Record(file=audio_path, url=audio_path)
|
Record(file=audio_path, url=audio_path)
|
||||||
)
|
)
|
||||||
|
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
|
||||||
|
new_chain.append(comp)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||||
@@ -192,7 +212,9 @@ class ResultDecorateStage(Stage):
|
|||||||
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
||||||
render_start = time.time()
|
render_start = time.time()
|
||||||
try:
|
try:
|
||||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
url = await html_renderer.render_t2i(
|
||||||
|
plain_str, return_url=True, use_network=self.t2i_use_network
|
||||||
|
)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error("文本转图片失败,使用文本发送。")
|
logger.error("文本转图片失败,使用文本发送。")
|
||||||
return
|
return
|
||||||
@@ -201,7 +223,10 @@ class ResultDecorateStage(Stage):
|
|||||||
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
|
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
|
||||||
)
|
)
|
||||||
if url:
|
if url:
|
||||||
result.chain = [Image.fromURL(url)]
|
if url.startswith("http"):
|
||||||
|
result.chain = [Image.fromURL(url)]
|
||||||
|
else:
|
||||||
|
result.chain = [Image.fromFileSystem(url)]
|
||||||
|
|
||||||
# 触发转发消息
|
# 触发转发消息
|
||||||
has_forwarded = False
|
has_forwarded = False
|
||||||
|
|||||||
@@ -7,49 +7,72 @@ from astrbot.core import logger
|
|||||||
|
|
||||||
|
|
||||||
class PipelineScheduler:
|
class PipelineScheduler:
|
||||||
|
"""管道调度器,负责调度各个阶段的执行"""
|
||||||
|
|
||||||
def __init__(self, context: PipelineContext):
|
def __init__(self, context: PipelineContext):
|
||||||
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__))
|
registered_stages.sort(
|
||||||
self.ctx = context
|
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||||
|
) # 按照顺序排序
|
||||||
|
self.ctx = context # 上下文对象
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
"""初始化管道调度器时, 初始化所有阶段"""
|
||||||
for stage in registered_stages:
|
for stage in registered_stages:
|
||||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||||
|
|
||||||
await stage.initialize(self.ctx)
|
await stage.initialize(self.ctx)
|
||||||
|
|
||||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||||
|
"""依次执行各个阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||||
|
"""
|
||||||
for i in range(from_stage, len(registered_stages)):
|
for i in range(from_stage, len(registered_stages)):
|
||||||
stage = registered_stages[i]
|
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||||
coro = stage.process(event)
|
coroutine = stage.process(
|
||||||
if isinstance(coro, AsyncGenerator):
|
event
|
||||||
async for _ in coro:
|
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||||
|
|
||||||
|
if isinstance(coroutine, AsyncGenerator):
|
||||||
|
# 如果返回的是异步生成器, 实现洋葱模型的核心
|
||||||
|
async for _ in coroutine:
|
||||||
|
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 递归调用, 处理所有后续阶段
|
||||||
await self._process_stages(event, i + 1)
|
await self._process_stages(event, i + 1)
|
||||||
|
|
||||||
|
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
await coro
|
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
|
||||||
|
# 简单地等待它执行完成, 然后继续执行下一个阶段
|
||||||
|
await coroutine
|
||||||
|
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
||||||
break
|
break
|
||||||
|
|
||||||
if event.is_stopped():
|
|
||||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
|
||||||
break
|
|
||||||
|
|
||||||
async def execute(self, event: AstrMessageEvent):
|
async def execute(self, event: AstrMessageEvent):
|
||||||
"""执行 pipeline"""
|
"""执行 pipeline
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
"""
|
||||||
await self._process_stages(event)
|
await self._process_stages(event)
|
||||||
|
|
||||||
|
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||||
await event.send(None)
|
await event.send(None)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import inspect
|
import inspect
|
||||||
|
import traceback
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from typing import List, AsyncGenerator, Union, Awaitable
|
from typing import List, AsyncGenerator, Union, Awaitable
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from .context import PipelineContext
|
from .context import PipelineContext
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||||
|
|
||||||
registered_stages: List[Stage] = []
|
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||||
"""维护了所有已注册的 Stage 实现类"""
|
|
||||||
|
|
||||||
|
|
||||||
def register_stage(cls):
|
def register_stage(cls):
|
||||||
@@ -22,14 +22,24 @@ class Stage(abc.ABC):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
"""初始化阶段"""
|
"""初始化阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""处理事件"""
|
"""处理事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象,包含事件的相关信息
|
||||||
|
Returns:
|
||||||
|
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _call_handler(
|
async def _call_handler(
|
||||||
@@ -40,33 +50,61 @@ class Stage(abc.ABC):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[None, None]:
|
) -> AsyncGenerator[None, None]:
|
||||||
"""调用 Handler。"""
|
"""执行事件处理函数并处理其返回结果
|
||||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
|
||||||
ready_to_call = None
|
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||||
|
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||||
|
2. 协程: 执行一次并处理返回值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象
|
||||||
|
event (AstrMessageEvent): 待处理的事件对象
|
||||||
|
handler (Awaitable): 事件处理函数
|
||||||
|
*args: 传递给handler的位置参数
|
||||||
|
**kwargs: 传递给handler的关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||||
|
"""
|
||||||
|
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||||
|
|
||||||
|
trace_ = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ready_to_call = handler(event, *args, **kwargs)
|
ready_to_call = handler(event, *args, **kwargs)
|
||||||
except TypeError as e:
|
except TypeError as _:
|
||||||
# 向下兼容
|
# 向下兼容
|
||||||
logger.debug(str(e))
|
trace_ = traceback.format_exc()
|
||||||
|
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(ready_to_call, AsyncGenerator):
|
if isinstance(ready_to_call, AsyncGenerator):
|
||||||
_has_yielded = False
|
# 如果是一个异步生成器, 进入洋葱模型
|
||||||
async for ret in ready_to_call:
|
_has_yielded = False # 是否返回过值
|
||||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
try:
|
||||||
_has_yielded = True
|
async for ret in ready_to_call:
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||||
event.set_result(ret)
|
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||||
|
_has_yielded = True
|
||||||
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
|
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||||
|
event.set_result(ret)
|
||||||
|
yield # 传递控制权给上一层的process函数
|
||||||
|
else:
|
||||||
|
# 如果返回值是 None, 则不设置结果并继续
|
||||||
|
# 继续执行后续阶段
|
||||||
|
yield ret # 传递控制权给上一层的process函数
|
||||||
|
if not _has_yielded:
|
||||||
|
# 如果这个异步生成器没有执行到yield分支
|
||||||
yield
|
yield
|
||||||
else:
|
except Exception as e:
|
||||||
yield ret
|
logger.error(f"Previous Error: {trace_}")
|
||||||
if not _has_yielded:
|
raise e
|
||||||
yield
|
|
||||||
elif inspect.iscoroutine(ready_to_call):
|
elif inspect.iscoroutine(ready_to_call):
|
||||||
# 如果只是一个 coroutine
|
# 如果只是一个协程, 直接执行
|
||||||
ret = await ready_to_call
|
ret = await ready_to_call
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
event.set_result(ret)
|
event.set_result(ret)
|
||||||
yield
|
yield # 传递控制权给上一层的process函数
|
||||||
else:
|
else:
|
||||||
yield ret
|
yield ret # 传递控制权给上一层的process函数
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from ..stage import Stage, register_stage
|
from ..stage import Stage, register_stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
|
from astrbot import logger
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||||
@@ -21,18 +22,38 @@ class WakingCheckStage(Stage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
"""初始化唤醒检查阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
||||||
"no_permission_reply", True
|
"no_permission_reply", True
|
||||||
)
|
)
|
||||||
|
# 私聊是否需要 wake_prefix 才能唤醒机器人
|
||||||
|
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
|
||||||
|
"platform_settings"
|
||||||
|
].get("friend_message_needs_wake_prefix", False)
|
||||||
|
# 是否忽略机器人自己发送的消息
|
||||||
|
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||||
|
"ignore_bot_self_message", False
|
||||||
|
)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
if (
|
||||||
|
self.ignore_bot_self_message
|
||||||
|
and event.get_self_id() == event.get_sender_id()
|
||||||
|
):
|
||||||
|
# 忽略机器人自己发送的消息
|
||||||
|
event.stop_event()
|
||||||
|
return
|
||||||
# 设置 sender 身份
|
# 设置 sender 身份
|
||||||
event.message_str = event.message_str.strip()
|
event.message_str = event.message_str.strip()
|
||||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||||
if event.get_sender_id() == admin_id:
|
if str(event.get_sender_id()) == admin_id:
|
||||||
event.role = "admin"
|
event.role = "admin"
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -68,7 +89,7 @@ class WakingCheckStage(Stage):
|
|||||||
event.is_at_or_wake_command = True
|
event.is_at_or_wake_command = True
|
||||||
break
|
break
|
||||||
# 检查是否是私聊
|
# 检查是否是私聊
|
||||||
if event.is_private_chat():
|
if event.is_private_chat() and not self.friend_message_needs_wake_prefix:
|
||||||
is_wake = True
|
is_wake = True
|
||||||
event.is_wake = True
|
event.is_wake = True
|
||||||
event.is_at_or_wake_command = True
|
event.is_at_or_wake_command = True
|
||||||
@@ -84,6 +105,7 @@ class WakingCheckStage(Stage):
|
|||||||
# filter 需满足 AND 逻辑关系
|
# filter 需满足 AND 逻辑关系
|
||||||
passed = True
|
passed = True
|
||||||
permission_not_pass = False
|
permission_not_pass = False
|
||||||
|
permission_filter_raise_error = False
|
||||||
if len(handler.event_filters) == 0:
|
if len(handler.event_filters) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -92,6 +114,7 @@ class WakingCheckStage(Stage):
|
|||||||
if isinstance(filter, PermissionTypeFilter):
|
if isinstance(filter, PermissionTypeFilter):
|
||||||
if not filter.filter(event, self.ctx.astrbot_config):
|
if not filter.filter(event, self.ctx.astrbot_config):
|
||||||
permission_not_pass = True
|
permission_not_pass = True
|
||||||
|
permission_filter_raise_error = filter.raise_error
|
||||||
else:
|
else:
|
||||||
if not filter.filter(event, self.ctx.astrbot_config):
|
if not filter.filter(event, self.ctx.astrbot_config):
|
||||||
passed = False
|
passed = False
|
||||||
@@ -102,17 +125,25 @@ class WakingCheckStage(Stage):
|
|||||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
await event._post_send()
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
passed = False
|
passed = False
|
||||||
break
|
break
|
||||||
if passed:
|
if passed:
|
||||||
if permission_not_pass:
|
if permission_not_pass:
|
||||||
|
if not permission_filter_raise_error:
|
||||||
|
# 跳过
|
||||||
|
continue
|
||||||
if self.no_permission_reply:
|
if self.no_permission_reply:
|
||||||
await event.send(
|
await event.send(
|
||||||
MessageChain().message(
|
MessageChain().message(
|
||||||
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
await event._post_send()
|
||||||
|
logger.info(
|
||||||
|
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||||
|
)
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ class WhitelistCheckStage(Stage):
|
|||||||
"enable_id_white_list"
|
"enable_id_white_list"
|
||||||
]
|
]
|
||||||
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
|
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
|
||||||
|
self.whitelist = [
|
||||||
|
str(i).strip() for i in self.whitelist if str(i).strip() != ""
|
||||||
|
]
|
||||||
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
|
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
|
||||||
"wl_ignore_admin_on_group"
|
"wl_ignore_admin_on_group"
|
||||||
]
|
]
|
||||||
@@ -51,7 +54,10 @@ class WhitelistCheckStage(Stage):
|
|||||||
and event.get_message_type() == MessageType.FRIEND_MESSAGE
|
and event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
if event.unified_msg_origin not in self.whitelist:
|
if (
|
||||||
|
event.unified_msg_origin not in self.whitelist
|
||||||
|
and str(event.get_group_id()).strip() not in self.whitelist
|
||||||
|
):
|
||||||
if self.wl_log:
|
if self.wl_log:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
|
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .platform import Platform
|
from .platform import Platform
|
||||||
from .astr_message_event import AstrMessageEvent
|
from .astr_message_event import AstrMessageEvent
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
from .astrbot_message import AstrBotMessage, MessageMember, MessageType
|
from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Platform",
|
"Platform",
|
||||||
@@ -10,4 +10,5 @@ __all__ = [
|
|||||||
"AstrBotMessage",
|
"AstrBotMessage",
|
||||||
"MessageMember",
|
"MessageMember",
|
||||||
"MessageType",
|
"MessageType",
|
||||||
|
"Group",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from .astrbot_message import AstrBotMessage
|
from typing import List, Union, Optional, AsyncGenerator
|
||||||
from .platform_metadata import PlatformMetadata
|
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
from astrbot.core.db.po import Conversation
|
||||||
from astrbot.core.platform.message_type import MessageType
|
|
||||||
from typing import List, Union
|
|
||||||
from astrbot.core.message.components import (
|
from astrbot.core.message.components import (
|
||||||
Plain,
|
Plain,
|
||||||
Image,
|
Image,
|
||||||
@@ -13,10 +15,14 @@ from astrbot.core.message.components import (
|
|||||||
At,
|
At,
|
||||||
AtAll,
|
AtAll,
|
||||||
Forward,
|
Forward,
|
||||||
|
Reply,
|
||||||
)
|
)
|
||||||
|
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||||
|
from astrbot.core.platform.message_type import MessageType
|
||||||
|
from astrbot.core.provider.entities import ProviderRequest
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from .astrbot_message import AstrBotMessage, Group
|
||||||
from astrbot.core.db.po import Conversation
|
from .platform_metadata import PlatformMetadata
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -78,6 +84,9 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
def get_platform_name(self):
|
def get_platform_name(self):
|
||||||
return self.platform_meta.name
|
return self.platform_meta.name
|
||||||
|
|
||||||
|
def get_platform_id(self):
|
||||||
|
return self.platform_meta.id
|
||||||
|
|
||||||
def get_message_str(self) -> str:
|
def get_message_str(self) -> str:
|
||||||
"""
|
"""
|
||||||
获取消息字符串。
|
获取消息字符串。
|
||||||
@@ -100,8 +109,15 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
elif isinstance(i, Forward):
|
elif isinstance(i, Forward):
|
||||||
# 转发消息
|
# 转发消息
|
||||||
outline += "[转发消息]"
|
outline += "[转发消息]"
|
||||||
|
elif isinstance(i, Reply):
|
||||||
|
# 引用回复
|
||||||
|
if i.message_str:
|
||||||
|
outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
|
||||||
|
else:
|
||||||
|
outline += "[引用消息]"
|
||||||
else:
|
else:
|
||||||
outline += f"[{i.type}]"
|
outline += f"[{i.type}]"
|
||||||
|
outline += " "
|
||||||
return outline
|
return outline
|
||||||
|
|
||||||
def get_message_outline(self) -> str:
|
def get_message_outline(self) -> str:
|
||||||
@@ -192,11 +208,30 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
return self.role == "admin"
|
return self.role == "admin"
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
|
||||||
"""
|
"""
|
||||||
发送消息到消息平台。
|
将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。
|
||||||
"""
|
"""
|
||||||
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
while True:
|
||||||
|
match = re.search(pattern, buffer)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
matched_text = match.group()
|
||||||
|
await self.send(MessageChain([Plain(matched_text)]))
|
||||||
|
buffer = buffer[match.end() :]
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||||
|
):
|
||||||
|
"""发送流式消息到消息平台,使用异步生成器。
|
||||||
|
目前仅支持: telegram,qq official 私聊。
|
||||||
|
Fallback仅支持 aiocqhttp, gewechat。
|
||||||
|
"""
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||||
|
)
|
||||||
self._has_send_oper = True
|
self._has_send_oper = True
|
||||||
|
|
||||||
async def _pre_send(self):
|
async def _pre_send(self):
|
||||||
@@ -360,3 +395,31 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""平台适配器"""
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
"""发送消息到消息平台。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (MessageChain): 消息链,具体使用方式请参考文档。
|
||||||
|
"""
|
||||||
|
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
|
||||||
|
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
|
||||||
|
sid = str(uuid.UUID(bytes=hash_obj.digest()))
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._has_send_oper = True
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||||
|
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||||
|
|
||||||
|
适配情况:
|
||||||
|
|
||||||
|
- gewechat
|
||||||
|
- aiocqhttp(OneBotv11)
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -10,6 +10,41 @@ class MessageMember:
|
|||||||
user_id: str # 发送者id
|
user_id: str # 发送者id
|
||||||
nickname: str = None
|
nickname: str = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# 使用 f-string 来构建返回的字符串表示形式
|
||||||
|
return (
|
||||||
|
f"User ID: {self.user_id},"
|
||||||
|
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Group:
|
||||||
|
group_id: str
|
||||||
|
"""群号"""
|
||||||
|
group_name: str = None
|
||||||
|
"""群名称"""
|
||||||
|
group_avatar: str = None
|
||||||
|
"""群头像"""
|
||||||
|
group_owner: str = None
|
||||||
|
"""群主 id"""
|
||||||
|
group_admins: List[str] = None
|
||||||
|
"""群管理员 id"""
|
||||||
|
members: List[MessageMember] = None
|
||||||
|
"""所有群成员"""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# 使用 f-string 来构建返回的字符串表示形式
|
||||||
|
return (
|
||||||
|
f"Group ID: {self.group_id}\n"
|
||||||
|
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
|
||||||
|
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
|
||||||
|
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
|
||||||
|
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
|
||||||
|
f"Members Len: {len(self.members) if self.members else 0}\n"
|
||||||
|
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AstrBotMessage:
|
class AstrBotMessage:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ class PlatformManager:
|
|||||||
)
|
)
|
||||||
case "lark":
|
case "lark":
|
||||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||||
|
case "dingtalk":
|
||||||
|
from .sources.dingtalk.dingtalk_adapter import (
|
||||||
|
DingtalkPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "telegram":
|
case "telegram":
|
||||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||||
case "wecom":
|
case "wecom":
|
||||||
@@ -81,14 +85,18 @@ class PlatformManager:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
cls_type = platform_cls_map[platform_config["type"]]
|
cls_type = platform_cls_map[platform_config["type"]]
|
||||||
inst = cls_type(platform_config, self.settings, self.event_queue)
|
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
|
||||||
self._inst_map[platform_config["id"]] = inst
|
self._inst_map[platform_config["id"]] = {
|
||||||
|
"inst": inst,
|
||||||
|
"client_id": inst.client_self_id,
|
||||||
|
}
|
||||||
self.platform_insts.append(inst)
|
self.platform_insts.append(inst)
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self._task_wrapper(
|
self._task_wrapper(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
inst.run(), name=platform_config["id"] + "_platform"
|
inst.run(),
|
||||||
|
name=f"platform_{platform_config['type']}_{platform_config['id']}",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -105,38 +113,42 @@ class PlatformManager:
|
|||||||
logger.error("-------")
|
logger.error("-------")
|
||||||
|
|
||||||
async def reload(self, platform_config: dict):
|
async def reload(self, platform_config: dict):
|
||||||
# 还未实现完成,不要调用此方法
|
await self.terminate_platform(platform_config["id"])
|
||||||
|
if platform_config["enable"]:
|
||||||
if platform_config["id"] in self._inst_map:
|
|
||||||
# 正在运行
|
|
||||||
if getattr(self._inst_map[platform_config["id"]], "terminate", None):
|
|
||||||
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
|
|
||||||
await self._inst_map[platform_config["id"]].terminate()
|
|
||||||
logger.info(f"{platform_config['id']} 平台适配器已终止。")
|
|
||||||
del self._inst_map[platform_config["id"]]
|
|
||||||
self.platform_insts.remove(self._inst_map[platform_config["id"]])
|
|
||||||
else:
|
|
||||||
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
|
|
||||||
|
|
||||||
# 再启动新的实例
|
|
||||||
await self.load_platform(platform_config)
|
await self.load_platform(platform_config)
|
||||||
|
|
||||||
else:
|
# 和配置文件保持同步
|
||||||
# 先将 _inst_map 中在 platform_config 中不存在的实例删除
|
config_ids = [provider["id"] for provider in self.platforms_config]
|
||||||
config_ids = [platform["id"] for platform in self.platforms_config]
|
for key in list(self._inst_map.keys()):
|
||||||
for key in list(self._inst_map.keys()):
|
if key not in config_ids:
|
||||||
if key not in config_ids:
|
await self.terminate_platform(key)
|
||||||
if getattr(self._inst_map[key], "terminate", None):
|
|
||||||
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
|
|
||||||
await self._inst_map[key].terminate()
|
|
||||||
logger.info(f"{key} 平台适配器已终止。")
|
|
||||||
del self._inst_map[key]
|
|
||||||
self.platform_insts.remove(self._inst_map[key])
|
|
||||||
else:
|
|
||||||
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
|
|
||||||
|
|
||||||
# 再启动新的实例
|
async def terminate_platform(self, platform_id: str):
|
||||||
await self.load_platform(platform_config)
|
if platform_id in self._inst_map:
|
||||||
|
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||||
|
|
||||||
|
# client_id = self._inst_map.pop(platform_id, None)
|
||||||
|
info = self._inst_map.pop(platform_id, None)
|
||||||
|
client_id = info["client_id"]
|
||||||
|
inst = info["inst"]
|
||||||
|
try:
|
||||||
|
self.platform_insts.remove(
|
||||||
|
next(
|
||||||
|
inst
|
||||||
|
for inst in self.platform_insts
|
||||||
|
if inst.client_self_id == client_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
|
||||||
|
|
||||||
|
if getattr(inst, "terminate", None):
|
||||||
|
await inst.terminate()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
for inst in self.platform_insts:
|
||||||
|
if getattr(inst, "terminate", None):
|
||||||
|
await inst.terminate()
|
||||||
|
|
||||||
def get_insts(self):
|
def get_insts(self):
|
||||||
return self.platform_insts
|
return self.platform_insts
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import uuid
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
@@ -13,6 +14,7 @@ class Platform(abc.ABC):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||||
self._event_queue = event_queue
|
self._event_queue = event_queue
|
||||||
|
self.client_self_id = uuid.uuid4().hex
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self) -> Awaitable[Any]:
|
def run(self) -> Awaitable[Any]:
|
||||||
@@ -25,7 +27,7 @@ class Platform(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
终止一个平台的运行实例。
|
终止一个平台的运行实例。
|
||||||
"""
|
"""
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ class PlatformMetadata:
|
|||||||
"""平台的名称"""
|
"""平台的名称"""
|
||||||
description: str
|
description: str
|
||||||
"""平台的描述"""
|
"""平台的描述"""
|
||||||
|
id: str = None
|
||||||
|
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||||
|
|
||||||
default_config_tmpl: dict = None
|
default_config_tmpl: dict = None
|
||||||
"""平台的默认配置模板"""
|
"""平台的默认配置模板"""
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from typing import AsyncGenerator, Dict, List
|
||||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
|
||||||
from aiocqhttp import CQHttp
|
from aiocqhttp import CQHttp
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record
|
||||||
|
from astrbot.api.platform import Group, MessageMember
|
||||||
|
|
||||||
|
|
||||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||||
@@ -21,20 +22,15 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
d = segment.toDict()
|
d = segment.toDict()
|
||||||
if isinstance(segment, Plain):
|
if isinstance(segment, Plain):
|
||||||
d["type"] = "text"
|
d["type"] = "text"
|
||||||
|
d["data"]["text"] = segment.text.strip()
|
||||||
|
# 如果是空文本或者只带换行符的文本,不发送
|
||||||
|
if not d["data"]["text"]:
|
||||||
|
continue
|
||||||
elif isinstance(segment, (Image, Record)):
|
elif isinstance(segment, (Image, Record)):
|
||||||
# convert to base64
|
# convert to base64
|
||||||
if segment.file and segment.file.startswith("file:///"):
|
bs64 = await segment.convert_to_base64()
|
||||||
bs64_data = file_to_base64(segment.file[8:])
|
|
||||||
image_file_path = segment.file[8:]
|
|
||||||
elif segment.file and segment.file.startswith("http"):
|
|
||||||
image_file_path = await download_image_by_url(segment.file)
|
|
||||||
bs64_data = file_to_base64(image_file_path)
|
|
||||||
elif segment.file and segment.file.startswith("base64://"):
|
|
||||||
bs64_data = segment.file
|
|
||||||
else:
|
|
||||||
bs64_data = file_to_base64(segment.file)
|
|
||||||
d["data"] = {
|
d["data"] = {
|
||||||
"file": bs64_data,
|
"file": f"base64://{bs64}",
|
||||||
}
|
}
|
||||||
elif isinstance(segment, At):
|
elif isinstance(segment, At):
|
||||||
d["data"] = {
|
d["data"] = {
|
||||||
@@ -46,6 +42,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
return
|
||||||
|
|
||||||
send_one_by_one = False
|
send_one_by_one = False
|
||||||
for seg in message.chain:
|
for seg in message.chain:
|
||||||
if isinstance(seg, (Node, Nodes)):
|
if isinstance(seg, (Node, Nodes)):
|
||||||
@@ -55,8 +54,13 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
if send_one_by_one:
|
if send_one_by_one:
|
||||||
for seg in message.chain:
|
for seg in message.chain:
|
||||||
if isinstance(seg, Nodes):
|
if isinstance(seg, (Node, Nodes)):
|
||||||
# 带有多个节点的合并转发消息
|
# 合并转发消息
|
||||||
|
|
||||||
|
if isinstance(seg, Node):
|
||||||
|
nodes = Nodes([seg])
|
||||||
|
seg = nodes
|
||||||
|
|
||||||
payload = seg.toDict()
|
payload = seg.toDict()
|
||||||
if self.get_group_id():
|
if self.get_group_id():
|
||||||
payload["group_id"] = self.get_group_id()
|
payload["group_id"] = self.get_group_id()
|
||||||
@@ -78,3 +82,80 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
await self.bot.send(self.message_obj.raw_message, ret)
|
await self.bot.send(self.message_obj.raw_message, ret)
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||||
|
):
|
||||||
|
if not use_fallback:
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
buffer += comp.text
|
||||||
|
if any(p in buffer for p in "。?!~…"):
|
||||||
|
buffer = await self.process_buffer(buffer, pattern)
|
||||||
|
else:
|
||||||
|
await self.send(MessageChain(chain=[comp]))
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
|
||||||
|
if buffer.strip():
|
||||||
|
await self.send(MessageChain([Plain(buffer)]))
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
async def get_group(self, group_id=None, **kwargs):
|
||||||
|
if isinstance(group_id, str) and group_id.isdigit():
|
||||||
|
group_id = int(group_id)
|
||||||
|
elif self.get_group_id():
|
||||||
|
group_id = int(self.get_group_id())
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
info: dict = await self.bot.call_action(
|
||||||
|
"get_group_info",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
members: List[Dict] = await self.bot.call_action(
|
||||||
|
"get_group_member_list",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
owner_id = None
|
||||||
|
admin_ids = []
|
||||||
|
for member in members:
|
||||||
|
if member["role"] == "owner":
|
||||||
|
owner_id = member["user_id"]
|
||||||
|
if member["role"] == "admin":
|
||||||
|
admin_ids.append(member["user_id"])
|
||||||
|
|
||||||
|
group = Group(
|
||||||
|
group_id=str(group_id),
|
||||||
|
group_name=info.get("group_name"),
|
||||||
|
group_avatar="",
|
||||||
|
group_admins=admin_ids,
|
||||||
|
group_owner=str(owner_id),
|
||||||
|
members=[
|
||||||
|
MessageMember(
|
||||||
|
user_id=member["user_id"],
|
||||||
|
nickname=member.get("nickname") or member.get("card"),
|
||||||
|
)
|
||||||
|
for member in members
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return group
|
||||||
|
|||||||
@@ -39,12 +39,11 @@ class AiocqhttpAdapter(Platform):
|
|||||||
self.port = platform_config["ws_reverse_port"]
|
self.port = platform_config["ws_reverse_port"]
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
"aiocqhttp",
|
name="aiocqhttp",
|
||||||
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stop = False
|
|
||||||
|
|
||||||
self.bot = CQHttp(
|
self.bot = CQHttp(
|
||||||
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
||||||
)
|
)
|
||||||
@@ -111,7 +110,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
"""OneBot V11 请求类事件"""
|
"""OneBot V11 请求类事件"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||||
abm.type = MessageType.OTHER_MESSAGE
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
if "group_id" in event and event["group_id"]:
|
if "group_id" in event and event["group_id"]:
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
@@ -131,7 +130,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
"""OneBot V11 通知类事件"""
|
"""OneBot V11 通知类事件"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||||
abm.type = MessageType.OTHER_MESSAGE
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
if "group_id" in event and event["group_id"]:
|
if "group_id" in event and event["group_id"]:
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
@@ -140,7 +139,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
abm.session_id = (
|
abm.session_id = (
|
||||||
abm.sender.user_id + "_" + str(event.group_id)
|
str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||||
) # 也保留群组 id
|
) # 也保留群组 id
|
||||||
else:
|
else:
|
||||||
abm.session_id = (
|
abm.session_id = (
|
||||||
@@ -160,8 +159,14 @@ class AiocqhttpAdapter(Platform):
|
|||||||
|
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
|
async def _convert_handle_message_event(
|
||||||
"""OneBot V11 消息类事件"""
|
self, event: Event, get_reply=True
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""OneBot V11 消息类事件
|
||||||
|
|
||||||
|
@param event: 事件对象
|
||||||
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
|
"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
@@ -240,6 +245,36 @@ class AiocqhttpAdapter(Platform):
|
|||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||||
|
|
||||||
|
elif t == "reply":
|
||||||
|
if not get_reply:
|
||||||
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
|
abm.message.append(a)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
reply_event_data = await self.bot.call_action(
|
||||||
|
action="get_msg",
|
||||||
|
message_id=int(m["data"]["id"]),
|
||||||
|
)
|
||||||
|
abm_reply = await self._convert_handle_message_event(
|
||||||
|
Event.from_payload(reply_event_data), get_reply=False
|
||||||
|
)
|
||||||
|
|
||||||
|
reply_seg = Reply(
|
||||||
|
id=abm_reply.message_id,
|
||||||
|
chain=abm_reply.message,
|
||||||
|
sender_id=abm_reply.sender.user_id,
|
||||||
|
sender_nickname=abm_reply.sender.nickname,
|
||||||
|
time=abm_reply.timestamp,
|
||||||
|
message_str=abm_reply.message_str,
|
||||||
|
text=abm_reply.message_str, # for compatibility
|
||||||
|
qq=abm_reply.sender.user_id, # for compatibility
|
||||||
|
)
|
||||||
|
|
||||||
|
abm.message.append(reply_seg)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"获取引用消息失败: {e}。")
|
||||||
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
|
abm.message.append(a)
|
||||||
else:
|
else:
|
||||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
abm.message.append(a)
|
abm.message.append(a)
|
||||||
@@ -267,22 +302,19 @@ class AiocqhttpAdapter(Platform):
|
|||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
|
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
return coro
|
return coro
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
self.stop = True
|
self.shutdown_event.set()
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
async def shutdown_trigger_placeholder(self):
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
logger.info("aiocqhttp 适配器已被优雅地关闭")
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return self.metadata
|
return self.metadata
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
|
||||||
# TODO: use asyncio.Event
|
|
||||||
while not self._event_queue.closed and not self.stop: # noqa: ASYNC110
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("aiocqhttp 适配器已关闭。")
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
message_event = AiocqhttpMessageEvent(
|
message_event = AiocqhttpMessageEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
|
|||||||
228
astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Normal file
228
astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import aiohttp
|
||||||
|
import dingtalk_stream
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Image, Plain, At
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from .dingtalk_event import DingtalkMessageEvent
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from astrbot import logger
|
||||||
|
from dingtalk_stream import AckMessage
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
|
|
||||||
|
|
||||||
|
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||||
|
async def process(self, event: dingtalk_stream.EventMessage):
|
||||||
|
print(
|
||||||
|
"2",
|
||||||
|
event.headers.event_type,
|
||||||
|
event.headers.event_id,
|
||||||
|
event.headers.event_born_time,
|
||||||
|
event.data,
|
||||||
|
)
|
||||||
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
|
||||||
|
class DingtalkPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
|
||||||
|
self.config = platform_config
|
||||||
|
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
self.client_id = platform_config["client_id"]
|
||||||
|
self.client_secret = platform_config["client_secret"]
|
||||||
|
|
||||||
|
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||||
|
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||||
|
logger.debug(f"dingtalk: {message.data}")
|
||||||
|
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||||
|
abm = await self.convert_msg(im)
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
|
self.client = AstrCallbackClient()
|
||||||
|
|
||||||
|
credential = dingtalk_stream.Credential(self.client_id, self.client_secret)
|
||||||
|
client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger)
|
||||||
|
client.register_all_event_handler(MyEventHandler())
|
||||||
|
client.register_callback_handler(
|
||||||
|
dingtalk_stream.ChatbotMessage.TOPIC, self.client
|
||||||
|
)
|
||||||
|
self.client_ = client # 用于 websockets 的 client
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="dingtalk",
|
||||||
|
description="钉钉机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def convert_msg(
|
||||||
|
self, message: dingtalk_stream.ChatbotMessage
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.message = []
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.timestamp = int(message.create_at / 1000)
|
||||||
|
abm.type = (
|
||||||
|
MessageType.GROUP_MESSAGE
|
||||||
|
if message.conversation_type == "2"
|
||||||
|
else MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
user_id=message.sender_id, nickname=message.sender_nick
|
||||||
|
)
|
||||||
|
abm.self_id = message.chatbot_user_id
|
||||||
|
abm.message_id = message.message_id
|
||||||
|
abm.raw_message = message
|
||||||
|
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
if message.is_in_at_list:
|
||||||
|
abm.message.append(At(qq=abm.self_id))
|
||||||
|
abm.group_id = message.conversation_id
|
||||||
|
if self.unique_session:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.group_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
|
message_type: str = message.message_type
|
||||||
|
match message_type:
|
||||||
|
case "text":
|
||||||
|
abm.message_str = message.text.content.strip()
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
case "richText":
|
||||||
|
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||||
|
contents: list[dict] = rtc.rich_text_list
|
||||||
|
for content in contents:
|
||||||
|
plains = ""
|
||||||
|
if "text" in content:
|
||||||
|
plains += content["text"]
|
||||||
|
abm.message.append(Plain(plains))
|
||||||
|
elif "type" in content and content["type"] == "picture":
|
||||||
|
f_path = await self.download_ding_file(
|
||||||
|
content["downloadCode"],
|
||||||
|
message.robot_code,
|
||||||
|
"jpg",
|
||||||
|
)
|
||||||
|
abm.message.append(Image.fromFileSystem(f_path))
|
||||||
|
case "audio":
|
||||||
|
pass
|
||||||
|
|
||||||
|
return abm # 别忘了返回转换后的消息对象
|
||||||
|
|
||||||
|
async def download_ding_file(
|
||||||
|
self, download_code: str, robot_code: str, ext: str
|
||||||
|
) -> str:
|
||||||
|
"""下载钉钉文件
|
||||||
|
|
||||||
|
:param access_token: 钉钉机器人的 access_token
|
||||||
|
:param download_code: 下载码
|
||||||
|
:param robot_code: 机器人码
|
||||||
|
:param ext: 文件后缀
|
||||||
|
:return: 文件路径
|
||||||
|
"""
|
||||||
|
access_token = await self.get_access_token()
|
||||||
|
headers = {
|
||||||
|
"x-acs-dingtalk-access-token": access_token,
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"downloadCode": download_code,
|
||||||
|
"robotCode": robot_code,
|
||||||
|
}
|
||||||
|
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"下载钉钉文件失败: {resp.status}, {await resp.text()}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
resp_data = await resp.json()
|
||||||
|
download_url = resp_data["data"]["downloadUrl"]
|
||||||
|
await download_file(download_url, f_path)
|
||||||
|
return f_path
|
||||||
|
|
||||||
|
async def get_access_token(self) -> str:
|
||||||
|
payload = {
|
||||||
|
"appKey": self.client_id,
|
||||||
|
"appSecret": self.client_secret,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return (await resp.json())["data"]["accessToken"]
|
||||||
|
|
||||||
|
async def handle_msg(self, abm: AstrBotMessage):
|
||||||
|
event = DingtalkMessageEvent(
|
||||||
|
message_str=abm.message_str,
|
||||||
|
message_obj=abm,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=abm.session_id,
|
||||||
|
client=self.client,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_queue.put_nowait(event)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
# await self.client_.start()
|
||||||
|
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
|
||||||
|
def start_client(loop: asyncio.AbstractEventLoop):
|
||||||
|
try:
|
||||||
|
self._shutdown_event = threading.Event()
|
||||||
|
task = loop.create_task(self.client_.start())
|
||||||
|
self._shutdown_event.wait()
|
||||||
|
if task.done():
|
||||||
|
task.result()
|
||||||
|
except Exception as e:
|
||||||
|
if "Graceful shutdown" in str(e):
|
||||||
|
logger.info("钉钉适配器已被优雅地关闭")
|
||||||
|
return
|
||||||
|
logger.error(f"钉钉机器人启动失败: {e}")
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(None, start_client, loop)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
def monkey_patch_close():
|
||||||
|
raise Exception("Graceful shutdown")
|
||||||
|
|
||||||
|
self.client_.open_connection = monkey_patch_close
|
||||||
|
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
return self.client
|
||||||
75
astrbot/core/platform/sources/dingtalk/dingtalk_event.py
Normal file
75
astrbot/core/platform/sources/dingtalk/dingtalk_event.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import asyncio
|
||||||
|
import dingtalk_stream
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
|
class DingtalkMessageEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str,
|
||||||
|
message_obj,
|
||||||
|
platform_meta,
|
||||||
|
session_id,
|
||||||
|
client: dingtalk_stream.ChatbotHandler,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def send_with_client(
|
||||||
|
self, client: dingtalk_stream.ChatbotHandler, message: MessageChain
|
||||||
|
):
|
||||||
|
for segment in message.chain:
|
||||||
|
if isinstance(segment, Comp.Plain):
|
||||||
|
segment.text = segment.text.strip()
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
client.reply_markdown,
|
||||||
|
"AstrBot",
|
||||||
|
segment.text,
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
)
|
||||||
|
elif isinstance(segment, Comp.Image):
|
||||||
|
markdown_str = ""
|
||||||
|
if segment.file and segment.file.startswith("file:///"):
|
||||||
|
logger.warning(
|
||||||
|
"dingtalk only support url image, not: " + segment.file
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif segment.file and segment.file.startswith("http"):
|
||||||
|
markdown_str += f"\n\n"
|
||||||
|
elif segment.file and segment.file.startswith("base64://"):
|
||||||
|
logger.warning("dingtalk only support url image, not base64")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"dingtalk only support url image, not: " + segment.file
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
ret = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
client.reply_markdown,
|
||||||
|
"😄",
|
||||||
|
markdown_str,
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
)
|
||||||
|
logger.debug(f"send image: {ret}")
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
await self.send_with_client(self.client, message)
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
@@ -1,17 +1,26 @@
|
|||||||
import threading
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
|
||||||
import quart
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import anyio
|
import anyio
|
||||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
import quart
|
||||||
from astrbot.api.message_components import Plain, Image, At, Record
|
|
||||||
from astrbot.api import logger, sp
|
from astrbot.api import logger, sp
|
||||||
from .downloader import GeweDownloader
|
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||||
|
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from .downloader import GeweDownloader
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .xml_data_parser import GeweDataParser
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
logger.warning(
|
||||||
|
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SimpleGewechatClient:
|
class SimpleGewechatClient:
|
||||||
@@ -51,11 +60,11 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"]
|
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||||
)
|
)
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/astrbot-gewechat/file/<file_id>",
|
"/astrbot-gewechat/file/<file_id>",
|
||||||
view_func=self.handle_file,
|
view_func=self._handle_file,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,9 +79,10 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
self.userrealnames = {}
|
self.userrealnames = {}
|
||||||
|
|
||||||
self.stop = False
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def get_token_id(self):
|
async def get_token_id(self):
|
||||||
|
"""获取 Gewechat Token。"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
@@ -87,6 +97,15 @@ class SimpleGewechatClient:
|
|||||||
type_name = data["type_name"]
|
type_name = data["type_name"]
|
||||||
else:
|
else:
|
||||||
raise Exception("无法识别的消息类型")
|
raise Exception("无法识别的消息类型")
|
||||||
|
|
||||||
|
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||||
|
if type_name == "ModContacts":
|
||||||
|
logger.info("gewechat下发:ModContacts消息通知。")
|
||||||
|
return
|
||||||
|
if type_name == "DelContacts":
|
||||||
|
logger.info("gewechat下发:DelContacts消息通知。")
|
||||||
|
return
|
||||||
|
|
||||||
if type_name == "Offline":
|
if type_name == "Offline":
|
||||||
logger.critical("收到 gewechat 下线通知。")
|
logger.critical("收到 gewechat 下线通知。")
|
||||||
return
|
return
|
||||||
@@ -147,6 +166,11 @@ class SimpleGewechatClient:
|
|||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
user_id = from_user_name
|
user_id = from_user_name
|
||||||
|
|
||||||
|
# 检查消息是否由自己发送,若是则忽略
|
||||||
|
if user_id == abm.self_id:
|
||||||
|
logger.info("忽略自己发送的消息")
|
||||||
|
return None
|
||||||
|
|
||||||
abm.message = []
|
abm.message = []
|
||||||
if at_me:
|
if at_me:
|
||||||
abm.message.insert(0, At(qq=abm.self_id))
|
abm.message.insert(0, At(qq=abm.self_id))
|
||||||
@@ -178,6 +202,11 @@ class SimpleGewechatClient:
|
|||||||
abm.sender = MessageMember(user_id, user_real_name)
|
abm.sender = MessageMember(user_id, user_real_name)
|
||||||
abm.raw_message = d
|
abm.raw_message = d
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
|
|
||||||
|
if user_id == "weixin":
|
||||||
|
# 忽略微信团队消息
|
||||||
|
return
|
||||||
|
|
||||||
# 不同消息类型
|
# 不同消息类型
|
||||||
match d["MsgType"]:
|
match d["MsgType"]:
|
||||||
case 1:
|
case 1:
|
||||||
@@ -195,18 +224,42 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
case 34:
|
case 34:
|
||||||
# 语音消息
|
# 语音消息
|
||||||
# data = await self.multimedia_downloader.download_voice(
|
|
||||||
# self.appid,
|
|
||||||
# content,
|
|
||||||
# abm.message_id
|
|
||||||
# )
|
|
||||||
# print(data)
|
|
||||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||||
|
|
||||||
async with await anyio.open_file(file_path, "wb") as f:
|
async with await anyio.open_file(file_path, "wb") as f:
|
||||||
await f.write(voice_data)
|
await f.write(voice_data)
|
||||||
abm.message.append(Record(file=file_path, url=file_path))
|
abm.message.append(Record(file=file_path, url=file_path))
|
||||||
|
|
||||||
|
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||||
|
case 37: # 好友申请
|
||||||
|
logger.info("消息类型(37):好友申请")
|
||||||
|
case 42: # 名片
|
||||||
|
logger.info("消息类型(42):名片")
|
||||||
|
case 43: # 视频
|
||||||
|
video = Video(file="", cover=content)
|
||||||
|
abm.message.append(video)
|
||||||
|
case 47: # emoji
|
||||||
|
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||||
|
emoji = data_parser.parse_emoji()
|
||||||
|
abm.message.append(emoji)
|
||||||
|
case 48: # 地理位置
|
||||||
|
logger.info("消息类型(48):地理位置")
|
||||||
|
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||||
|
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||||
|
abm_data = data_parser.parse_mutil_49()
|
||||||
|
if abm_data:
|
||||||
|
abm.message.append(abm_data)
|
||||||
|
case 51: # 帐号消息同步?
|
||||||
|
logger.info("消息类型(51):帐号消息同步?")
|
||||||
|
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||||
|
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||||
|
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||||
|
logger.info(
|
||||||
|
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||||
abm.raw_message = d
|
abm.raw_message = d
|
||||||
@@ -214,7 +267,7 @@ class SimpleGewechatClient:
|
|||||||
logger.debug(f"abm: {abm}")
|
logger.debug(f"abm: {abm}")
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
async def callback(self):
|
async def _callback(self):
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
logger.debug(f"收到 gewechat 回调: {data}")
|
logger.debug(f"收到 gewechat 回调: {data}")
|
||||||
|
|
||||||
@@ -236,7 +289,7 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
return quart.jsonify({"r": "AstrBot ACK"})
|
return quart.jsonify({"r": "AstrBot ACK"})
|
||||||
|
|
||||||
async def handle_file(self, file_id):
|
async def _handle_file(self, file_id):
|
||||||
file_path = f"data/temp/{file_id}"
|
file_path = f"data/temp/{file_id}"
|
||||||
return await quart.send_file(file_path)
|
return await quart.send_file(file_path)
|
||||||
|
|
||||||
@@ -262,17 +315,14 @@ class SimpleGewechatClient:
|
|||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder,
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
# TODO: use asyncio.Event
|
await self.shutdown_event.wait()
|
||||||
while not self.event_queue.closed and not self.stop: # noqa: ASYNC110
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("gewechat 适配器已关闭。")
|
|
||||||
|
|
||||||
async def check_online(self, appid: str):
|
async def check_online(self, appid: str):
|
||||||
# /login/checkOnline
|
"""检查 APPID 对应的设备是否在线。"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/login/checkOnline",
|
f"{self.base_url}/login/checkOnline",
|
||||||
@@ -283,6 +333,7 @@ class SimpleGewechatClient:
|
|||||||
return json_blob["data"]
|
return json_blob["data"]
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
|
"""登出 gewechat。"""
|
||||||
if self.appid:
|
if self.appid:
|
||||||
online = await self.check_online(self.appid)
|
online = await self.check_online(self.appid)
|
||||||
if online:
|
if online:
|
||||||
@@ -296,6 +347,7 @@ class SimpleGewechatClient:
|
|||||||
logger.info(f"登出结果: {json_blob}")
|
logger.info(f"登出结果: {json_blob}")
|
||||||
|
|
||||||
async def login(self):
|
async def login(self):
|
||||||
|
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||||
if self.token is None:
|
if self.token is None:
|
||||||
await self.get_token_id()
|
await self.get_token_id()
|
||||||
|
|
||||||
@@ -304,32 +356,49 @@ class SimpleGewechatClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.appid:
|
if self.appid:
|
||||||
online = await self.check_online(self.appid)
|
try:
|
||||||
if online:
|
online = await self.check_online(self.appid)
|
||||||
logger.info(f"APPID: {self.appid} 已在线")
|
if online:
|
||||||
return
|
logger.info(f"APPID: {self.appid} 已在线")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查在线状态失败: {e}")
|
||||||
|
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||||
|
self.appid = None
|
||||||
|
|
||||||
payload = {"appId": self.appid}
|
payload = {"appId": self.appid}
|
||||||
|
|
||||||
if self.appid:
|
if self.appid:
|
||||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
try:
|
||||||
async with session.post(
|
async with aiohttp.ClientSession() as session:
|
||||||
f"{self.base_url}/login/getLoginQrCode",
|
async with session.post(
|
||||||
headers=self.headers,
|
f"{self.base_url}/login/getLoginQrCode",
|
||||||
json=payload,
|
headers=self.headers,
|
||||||
) as resp:
|
json=payload,
|
||||||
json_blob = await resp.json()
|
) as resp:
|
||||||
if json_blob["ret"] != 200:
|
json_blob = await resp.json()
|
||||||
raise Exception(f"获取二维码失败: {json_blob}")
|
if json_blob["ret"] != 200:
|
||||||
qr_data = json_blob["data"]["qrData"]
|
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||||
qr_uuid = json_blob["data"]["uuid"]
|
if "设备不存在" in error_msg:
|
||||||
appid = json_blob["data"]["appId"]
|
logger.error(
|
||||||
logger.info(f"APPID: {appid}")
|
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||||
logger.warning(
|
)
|
||||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||||
)
|
self.appid = None
|
||||||
|
return await self.login()
|
||||||
|
else:
|
||||||
|
raise Exception(f"获取二维码失败: {json_blob}")
|
||||||
|
qr_data = json_blob["data"]["qrData"]
|
||||||
|
qr_uuid = json_blob["data"]["uuid"]
|
||||||
|
appid = json_blob["data"]["appId"]
|
||||||
|
logger.info(f"APPID: {appid}")
|
||||||
|
logger.warning(
|
||||||
|
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
# 执行登录
|
# 执行登录
|
||||||
retry_cnt = 64
|
retry_cnt = 64
|
||||||
@@ -390,9 +459,18 @@ class SimpleGewechatClient:
|
|||||||
self.appid = appid
|
self.appid = appid
|
||||||
logger.info(f"已保存 APPID: {appid}")
|
logger.info(f"已保存 APPID: {appid}")
|
||||||
|
|
||||||
"""API"""
|
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||||
|
"""
|
||||||
|
|
||||||
async def get_chatroom_member_list(self, chatroom_wxid: str):
|
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||||
|
"""获取群成员列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||||
|
"""
|
||||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
@@ -405,6 +483,7 @@ class SimpleGewechatClient:
|
|||||||
return json_blob["data"]
|
return json_blob["data"]
|
||||||
|
|
||||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||||
|
"""发送纯文本消息"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -421,6 +500,7 @@ class SimpleGewechatClient:
|
|||||||
logger.debug(f"发送消息结果: {json_blob}")
|
logger.debug(f"发送消息结果: {json_blob}")
|
||||||
|
|
||||||
async def post_image(self, to_wxid, image_url: str):
|
async def post_image(self, to_wxid, image_url: str):
|
||||||
|
"""发送图片消息"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -434,7 +514,79 @@ class SimpleGewechatClient:
|
|||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送图片结果: {json_blob}")
|
logger.debug(f"发送图片结果: {json_blob}")
|
||||||
|
|
||||||
|
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||||
|
"""发送emoji消息"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"emojiMd5": emoji_md5,
|
||||||
|
"emojiSize": emoji_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||||
|
try:
|
||||||
|
if emoji_md5 != "" and emoji_size != "":
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postEmoji",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.info(
|
||||||
|
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.post_image(to_wxid, cdnurl)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
async def post_video(
|
||||||
|
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"videoUrl": video_url,
|
||||||
|
"thumbUrl": thumb_url,
|
||||||
|
"videoDuration": video_duration,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"发送视频结果: {json_blob}")
|
||||||
|
|
||||||
|
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||||
|
"""转发视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_wxid (str): 发送给谁
|
||||||
|
cnd_xml (str): 视频消息的cdn信息
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"xml": cnd_xml,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/forwardVideo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"转发视频结果: {json_blob}")
|
||||||
|
|
||||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||||
|
"""发送语音信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice_url (str): 语音文件的网络链接
|
||||||
|
voice_duration (int): 语音时长,毫秒
|
||||||
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -449,9 +601,16 @@ class SimpleGewechatClient:
|
|||||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送语音结果: {json_blob}")
|
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||||
|
|
||||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||||
|
"""发送文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_wxid (string): 微信ID
|
||||||
|
file_url (str): 文件的网络链接
|
||||||
|
file_name (str): 文件名
|
||||||
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -465,3 +624,131 @@ class SimpleGewechatClient:
|
|||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送文件结果: {json_blob}")
|
logger.debug(f"发送文件结果: {json_blob}")
|
||||||
|
|
||||||
|
async def add_friend(self, v3: str, v4: str, content: str):
|
||||||
|
"""申请添加好友"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"scene": 3,
|
||||||
|
"content": content,
|
||||||
|
"v4": v4,
|
||||||
|
"v3": v3,
|
||||||
|
"option": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/addContacts",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomInfo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_group_member(self, group_id: str):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomMemberList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def accept_group_invite(self, url: str):
|
||||||
|
"""同意进群"""
|
||||||
|
payload = {"appId": self.appid, "url": url}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/agreeJoinRoom",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def add_group_member_to_friend(
|
||||||
|
self, group_id: str, to_wxid: str, content: str
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
"content": content,
|
||||||
|
"memberWxid": to_wxid,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_user_or_group_info(self, *ids):
|
||||||
|
"""
|
||||||
|
获取用户或群组信息。
|
||||||
|
|
||||||
|
:param ids: 可变数量的 wxid 参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
wxids_str = list(ids)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/getDetailInfo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_contacts_list(self):
|
||||||
|
"""
|
||||||
|
获取通讯录列表
|
||||||
|
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||||
|
"""
|
||||||
|
payload = {"appId": self.appid}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/fetchContactsList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|||||||
@@ -39,3 +39,17 @@ class GeweDownloader:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
raise Exception("无法下载图片")
|
raise Exception("无法下载图片")
|
||||||
|
|
||||||
|
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||||
|
"""下载emoji"""
|
||||||
|
try:
|
||||||
|
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||||
|
|
||||||
|
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||||
|
data = await self._post_json(
|
||||||
|
self.base_url, "/message/downloadEmojiMd5", payload
|
||||||
|
)
|
||||||
|
json_blob = json.loads(data)
|
||||||
|
return json_blob
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"gewe download emoji: {e}")
|
||||||
|
|||||||
@@ -1,13 +1,25 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
import wave
|
import wave
|
||||||
import uuid
|
import uuid
|
||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from astrbot.core.utils.io import save_temp_img, download_file
|
||||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||||
from astrbot.api.message_components import Plain, Image, Record, At, File
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
Record,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Video,
|
||||||
|
WechatEmoji as Emoji,
|
||||||
|
)
|
||||||
from .client import SimpleGewechatClient
|
from .client import SimpleGewechatClient
|
||||||
|
|
||||||
|
|
||||||
@@ -37,12 +49,9 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def send_with_client(message: MessageChain, user_name: str):
|
async def send_with_client(
|
||||||
pass
|
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
||||||
|
):
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
|
||||||
|
|
||||||
if not to_wxid:
|
if not to_wxid:
|
||||||
logger.error("无法获取到 to_wxid。")
|
logger.error("无法获取到 to_wxid。")
|
||||||
return
|
return
|
||||||
@@ -70,56 +79,94 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
payload["content"] = text
|
payload["content"] = text
|
||||||
payload["ats"] = ats
|
payload["ats"] = ats
|
||||||
has_at = True
|
has_at = True
|
||||||
await self.client.post_text(**payload)
|
await client.post_text(**payload)
|
||||||
|
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
img_url = comp.file
|
img_path = await comp.convert_to_file_path()
|
||||||
img_path = ""
|
|
||||||
if img_url.startswith("file:///"):
|
|
||||||
img_path = img_url[8:]
|
|
||||||
elif comp.file and comp.file.startswith("http"):
|
|
||||||
img_path = await download_image_by_url(comp.file)
|
|
||||||
else:
|
|
||||||
img_path = img_url
|
|
||||||
|
|
||||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
# 检查 record_path 是否在 data/temp 目录中
|
||||||
temp_directory = os.path.abspath("data/temp")
|
temp_directory = os.path.abspath("data/temp")
|
||||||
img_path = os.path.abspath(img_path)
|
|
||||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||||
with open(img_path, "rb") as f:
|
with open(img_path, "rb") as f:
|
||||||
img_path = save_temp_img(f.read())
|
img_path = save_temp_img(f.read())
|
||||||
|
|
||||||
file_id = os.path.basename(img_path)
|
file_id = os.path.basename(img_path)
|
||||||
img_url = f"{self.client.file_server_url}/{file_id}"
|
img_url = f"{client.file_server_url}/{file_id}"
|
||||||
logger.debug(f"gewe callback img url: {img_url}")
|
logger.debug(f"gewe callback img url: {img_url}")
|
||||||
await self.client.post_image(to_wxid, img_url)
|
await client.post_image(to_wxid, img_url)
|
||||||
|
elif isinstance(comp, Video):
|
||||||
|
if comp.cover != "":
|
||||||
|
await client.forward_video(to_wxid, comp.cover)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFmpeg
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
logger.error(
|
||||||
|
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||||
|
)
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
video_url = comp.file
|
||||||
|
# 根据 url 下载视频
|
||||||
|
video_filename = f"{uuid.uuid4()}.mp4"
|
||||||
|
video_path = f"data/temp/{video_filename}"
|
||||||
|
await download_file(video_url, video_path)
|
||||||
|
|
||||||
|
# 获取视频第一帧
|
||||||
|
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||||
|
try:
|
||||||
|
ff = FFmpeg()
|
||||||
|
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
|
||||||
|
ff.options(command)
|
||||||
|
thumb_file_id = os.path.basename(thumb_path)
|
||||||
|
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取视频第一帧失败: {e}")
|
||||||
|
# 获取视频时长
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFprobe
|
||||||
|
|
||||||
|
# 创建 FFprobe 实例
|
||||||
|
ffprobe = FFprobe(video_url)
|
||||||
|
# 获取时长字符串
|
||||||
|
duration_str = ffprobe.duration
|
||||||
|
# 处理时长字符串
|
||||||
|
video_duration = float(duration_str.replace(":", ""))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取时长失败: {e}")
|
||||||
|
video_duration = 10
|
||||||
|
|
||||||
|
file_id = os.path.basename(video_path)
|
||||||
|
video_url = f"{client.file_server_url}/{file_id}"
|
||||||
|
await client.post_video(
|
||||||
|
to_wxid, video_url, thumb_url, video_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除临时视频和缩略图文件
|
||||||
|
if os.path.exists(video_path):
|
||||||
|
os.remove(video_path)
|
||||||
|
if os.path.exists(thumb_path):
|
||||||
|
os.remove(thumb_path)
|
||||||
elif isinstance(comp, Record):
|
elif isinstance(comp, Record):
|
||||||
# 默认已经存在 data/temp 中
|
# 默认已经存在 data/temp 中
|
||||||
record_url = comp.file
|
record_url = comp.file
|
||||||
record_path = ""
|
record_path = await comp.convert_to_file_path()
|
||||||
|
|
||||||
if record_url.startswith("file:///"):
|
|
||||||
record_path = record_url[8:]
|
|
||||||
elif record_url.startswith("http"):
|
|
||||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
|
||||||
else:
|
|
||||||
record_path = record_url
|
|
||||||
|
|
||||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||||
try:
|
try:
|
||||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
await self.send(
|
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
||||||
MessageChain().message(f"语音文件转换失败。{str(e)}")
|
|
||||||
)
|
|
||||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||||
if duration == 0:
|
if duration == 0:
|
||||||
duration = get_wav_duration(record_path)
|
duration = get_wav_duration(record_path)
|
||||||
file_id = os.path.basename(silk_path)
|
file_id = os.path.basename(silk_path)
|
||||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
record_url = f"{client.file_server_url}/{file_id}"
|
||||||
logger.debug(f"gewe callback record url: {record_url}")
|
logger.debug(f"gewe callback record url: {record_url}")
|
||||||
await self.client.post_voice(to_wxid, record_url, duration * 1000)
|
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||||
elif isinstance(comp, File):
|
elif isinstance(comp, File):
|
||||||
file_path = comp.file
|
file_path = comp.file
|
||||||
file_name = comp.name
|
file_name = comp.name
|
||||||
@@ -131,12 +178,78 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
file_path = file_path
|
file_path = file_path
|
||||||
|
|
||||||
file_id = os.path.basename(file_path)
|
file_id = os.path.basename(file_path)
|
||||||
file_url = f"{self.client.file_server_url}/{file_id}"
|
file_url = f"{client.file_server_url}/{file_id}"
|
||||||
logger.debug(f"gewe callback file url: {file_url}")
|
logger.debug(f"gewe callback file url: {file_url}")
|
||||||
await self.client.post_file(to_wxid, file_url, file_id)
|
await client.post_file(to_wxid, file_url, file_id)
|
||||||
|
elif isinstance(comp, Emoji):
|
||||||
|
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||||
elif isinstance(comp, At):
|
elif isinstance(comp, At):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||||
|
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def get_group(self, group_id=None, **kwargs):
|
||||||
|
# 确定有效的 group_id
|
||||||
|
if group_id is None:
|
||||||
|
group_id = self.get_group_id()
|
||||||
|
|
||||||
|
if not group_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
res = await self.client.get_group(group_id)
|
||||||
|
data: dict = res["data"]
|
||||||
|
|
||||||
|
if not data["chatroomId"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
members = [
|
||||||
|
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||||
|
for member in data.get("memberList", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
return Group(
|
||||||
|
group_id=data["chatroomId"],
|
||||||
|
group_name=data.get("nickName"),
|
||||||
|
group_avatar=data.get("smallHeadImgUrl"),
|
||||||
|
group_owner=data.get("chatRoomOwner"),
|
||||||
|
members=members,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||||
|
):
|
||||||
|
if not use_fallback:
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
buffer += comp.text
|
||||||
|
if any(p in buffer for p in "。?!~…"):
|
||||||
|
buffer = await self.process_buffer(buffer, pattern)
|
||||||
|
else:
|
||||||
|
await self.send(MessageChain(chain=[comp]))
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
|
||||||
|
if buffer.strip():
|
||||||
|
await self.send(MessageChain([Plain(buffer)]))
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ import os
|
|||||||
|
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
from .gewechat_event import GewechatPlatformEvent
|
from .gewechat_event import GewechatPlatformEvent
|
||||||
from .client import SimpleGewechatClient
|
from .client import SimpleGewechatClient
|
||||||
from astrbot.core.message.components import Plain
|
from astrbot import logger
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -45,27 +44,34 @@ class GewechatPlatformAdapter(Platform):
|
|||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
):
|
):
|
||||||
to_wxid = session.session_id
|
session_id = session.session_id
|
||||||
if not to_wxid:
|
if "#" in session_id:
|
||||||
logger.error("无法获取到 to_wxid。")
|
# unique session
|
||||||
return
|
to_wxid = session_id.split("#")[1]
|
||||||
|
else:
|
||||||
|
to_wxid = session_id
|
||||||
|
|
||||||
for comp in message_chain.chain:
|
await GewechatPlatformEvent.send_with_client(
|
||||||
if isinstance(comp, Plain):
|
message_chain, to_wxid, self.client
|
||||||
await self.client.post_text(to_wxid, comp.text)
|
)
|
||||||
|
|
||||||
await super().send_by_session(session, message_chain)
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"gewechat",
|
name="gewechat",
|
||||||
"基于 gewechat 的 Wechat 适配器",
|
description="基于 gewechat 的 Wechat 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
self.client.stop = True
|
self.client.shutdown_event.set()
|
||||||
await asyncio.sleep(1)
|
try:
|
||||||
|
await self.client.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
await self.client.logout()
|
await self.client.logout()
|
||||||
@@ -81,7 +87,7 @@ class GewechatPlatformAdapter(Platform):
|
|||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
if message.type == MessageType.GROUP_MESSAGE:
|
if message.type == MessageType.GROUP_MESSAGE:
|
||||||
if self.settingss["unique_session"]:
|
if self.settingss["unique_session"]:
|
||||||
message.session_id = message.sender.user_id + "_" + message.group_id
|
message.session_id = message.sender.user_id + "#" + message.group_id
|
||||||
|
|
||||||
message_event = GewechatPlatformEvent(
|
message_event = GewechatPlatformEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
|
|||||||
78
astrbot/core/platform/sources/gewechat/xml_data_parser.py
Normal file
78
astrbot/core/platform/sources/gewechat/xml_data_parser.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from defusedxml import ElementTree as eT
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
|
||||||
|
|
||||||
|
|
||||||
|
class GeweDataParser:
|
||||||
|
def __init__(self, data, is_private_chat):
|
||||||
|
self.data = data
|
||||||
|
self.is_private_chat = is_private_chat
|
||||||
|
|
||||||
|
def _format_to_xml(self):
|
||||||
|
return eT.fromstring(self.data)
|
||||||
|
|
||||||
|
def parse_mutil_49(self):
|
||||||
|
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||||
|
if appmsg_type is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
match appmsg_type.text:
|
||||||
|
case "57":
|
||||||
|
return self.parse_reply()
|
||||||
|
|
||||||
|
def parse_emoji(self) -> Emoji | None:
|
||||||
|
try:
|
||||||
|
emoji_element = self._format_to_xml().find(".//emoji")
|
||||||
|
# 提取 md5 和 len 属性
|
||||||
|
if emoji_element is not None:
|
||||||
|
md5_value = emoji_element.get("md5")
|
||||||
|
emoji_size = emoji_element.get("len")
|
||||||
|
cdnurl = emoji_element.get("cdnurl")
|
||||||
|
|
||||||
|
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||||
|
|
||||||
|
def parse_reply(self) -> Reply | None:
|
||||||
|
try:
|
||||||
|
replied_id = -1
|
||||||
|
replied_uid = 0
|
||||||
|
replied_nickname = ""
|
||||||
|
replied_content = ""
|
||||||
|
content = ""
|
||||||
|
|
||||||
|
root = self._format_to_xml()
|
||||||
|
refermsg = root.find(".//refermsg")
|
||||||
|
if refermsg is not None:
|
||||||
|
# 被引用的信息
|
||||||
|
svrid = refermsg.find("svrid")
|
||||||
|
fromusr = refermsg.find("fromusr")
|
||||||
|
displayname = refermsg.find("displayname")
|
||||||
|
refermsg_content = refermsg.find("content")
|
||||||
|
if svrid is not None:
|
||||||
|
replied_id = svrid.text
|
||||||
|
if fromusr is not None:
|
||||||
|
replied_uid = fromusr.text
|
||||||
|
if displayname is not None:
|
||||||
|
replied_nickname = displayname.text
|
||||||
|
if refermsg_content is not None:
|
||||||
|
replied_content = refermsg_content.text
|
||||||
|
|
||||||
|
# 提取引用者说的内容
|
||||||
|
title = root.find(".//appmsg/title")
|
||||||
|
if title is not None:
|
||||||
|
content = title.text
|
||||||
|
|
||||||
|
r = Reply(
|
||||||
|
id=replied_id,
|
||||||
|
chain=[Plain(content)],
|
||||||
|
sender_id=replied_uid,
|
||||||
|
sender_nickname=replied_nickname,
|
||||||
|
sender_str=replied_content,
|
||||||
|
message_str=content,
|
||||||
|
)
|
||||||
|
return r
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||||
@@ -2,6 +2,8 @@ import base64
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import uuid
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
Platform,
|
Platform,
|
||||||
@@ -11,7 +13,6 @@ from astrbot.api.platform import (
|
|||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
)
|
)
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api.message_components import Image, Plain, At
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from .lark_event import LarkMessageEvent
|
from .lark_event import LarkMessageEvent
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
@@ -66,12 +67,47 @@ class LarkPlatformAdapter(Platform):
|
|||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
):
|
):
|
||||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||||
|
wrapped = {
|
||||||
|
"zh_cn": {
|
||||||
|
"title": "",
|
||||||
|
"content": res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||||
|
id_type = "chat_id"
|
||||||
|
if "%" in session.session_id:
|
||||||
|
session.session_id = session.session_id.split("%")[1]
|
||||||
|
else:
|
||||||
|
id_type = "open_id"
|
||||||
|
|
||||||
|
request = (
|
||||||
|
CreateMessageRequest.builder()
|
||||||
|
.receive_id_type(id_type)
|
||||||
|
.request_body(
|
||||||
|
CreateMessageRequestBody.builder()
|
||||||
|
.receive_id(session.session_id)
|
||||||
|
.content(json.dumps(wrapped))
|
||||||
|
.msg_type("post")
|
||||||
|
.uuid(str(uuid.uuid4()))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.lark_api.im.v1.message.acreate(request)
|
||||||
|
|
||||||
|
if not response.success():
|
||||||
|
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||||
|
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"lark",
|
name="lark",
|
||||||
"飞书机器人官方 API 适配器",
|
description="飞书机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||||
@@ -92,7 +128,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
at_list = {}
|
at_list = {}
|
||||||
if message.mentions:
|
if message.mentions:
|
||||||
for m in message.mentions:
|
for m in message.mentions:
|
||||||
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
|
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||||
if m.name == self.bot_name:
|
if m.name == self.bot_name:
|
||||||
abm.self_id = m.id.open_id
|
abm.self_id = m.id.open_id
|
||||||
|
|
||||||
@@ -111,7 +147,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
if s in at_list:
|
if s in at_list:
|
||||||
abm.message.append(at_list[s])
|
abm.message.append(at_list[s])
|
||||||
else:
|
else:
|
||||||
abm.message.append(Plain(parts[i].strip()))
|
abm.message.append(Comp.Plain(parts[i].strip()))
|
||||||
elif message.message_type == "post":
|
elif message.message_type == "post":
|
||||||
_ls = []
|
_ls = []
|
||||||
|
|
||||||
@@ -132,7 +168,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
if comp["tag"] == "at":
|
if comp["tag"] == "at":
|
||||||
abm.message.append(at_list[comp["user_id"]])
|
abm.message.append(at_list[comp["user_id"]])
|
||||||
elif comp["tag"] == "text" and comp["text"].strip():
|
elif comp["tag"] == "text" and comp["text"].strip():
|
||||||
abm.message.append(Plain(comp["text"].strip()))
|
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||||
elif comp["tag"] == "img":
|
elif comp["tag"] == "img":
|
||||||
image_key = comp["image_key"]
|
image_key = comp["image_key"]
|
||||||
request = (
|
request = (
|
||||||
@@ -147,10 +183,10 @@ class LarkPlatformAdapter(Platform):
|
|||||||
logger.error(f"无法下载飞书图片: {image_key}")
|
logger.error(f"无法下载飞书图片: {image_key}")
|
||||||
image_bytes = response.file.read()
|
image_bytes = response.file.read()
|
||||||
image_base64 = base64.b64encode(image_bytes).decode()
|
image_base64 = base64.b64encode(image_bytes).decode()
|
||||||
abm.message.append(Image.fromBase64(image_base64))
|
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||||
|
|
||||||
for comp in abm.message:
|
for comp in abm.message:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Comp.Plain):
|
||||||
abm.message_str += comp.text
|
abm.message_str += comp.text
|
||||||
abm.message_id = message.message_id
|
abm.message_id = message.message_id
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
@@ -165,7 +201,10 @@ class LarkPlatformAdapter(Platform):
|
|||||||
else:
|
else:
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
else:
|
else:
|
||||||
abm.session_id = abm.sender.user_id
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
logger.debug(abm)
|
logger.debug(abm)
|
||||||
await self.handle_msg(abm)
|
await self.handle_msg(abm)
|
||||||
@@ -185,5 +224,9 @@ class LarkPlatformAdapter(Platform):
|
|||||||
# self.client.start()
|
# self.client.start()
|
||||||
await self.client._connect()
|
await self.client._connect()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
await self.client._disconnect()
|
||||||
|
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||||
|
|
||||||
def get_client(self) -> lark.Client:
|
def get_client(self) -> lark.Client:
|
||||||
return self.client
|
return self.client
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
import base64
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
|
from io import BytesIO
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||||
@@ -27,22 +29,32 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
||||||
elif isinstance(comp, AstrBotImage):
|
elif isinstance(comp, AstrBotImage):
|
||||||
file_path = ""
|
file_path = ""
|
||||||
|
image_file = None
|
||||||
|
|
||||||
if comp.file and comp.file.startswith("file:///"):
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
file_path = comp.file.replace("file:///", "")
|
file_path = comp.file.replace("file:///", "")
|
||||||
elif comp.file and comp.file.startswith("http"):
|
elif comp.file and comp.file.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(comp.file)
|
image_file_path = await download_image_by_url(comp.file)
|
||||||
file_path = image_file_path
|
file_path = image_file_path
|
||||||
elif comp.file and comp.file.startswith("base64://"):
|
elif comp.file and comp.file.startswith("base64://"):
|
||||||
pass
|
base64_str = comp.file.removeprefix("base64://")
|
||||||
|
image_data = base64.b64decode(base64_str)
|
||||||
|
# save as temp file
|
||||||
|
file_path = f"data/temp/{uuid.uuid4()}_test.jpg"
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(BytesIO(image_data).getvalue())
|
||||||
else:
|
else:
|
||||||
file_path = comp.file
|
file_path = comp.file
|
||||||
|
|
||||||
|
if image_file is None:
|
||||||
|
image_file = open(file_path, "rb")
|
||||||
|
|
||||||
request = (
|
request = (
|
||||||
CreateImageRequest.builder()
|
CreateImageRequest.builder()
|
||||||
.request_body(
|
.request_body(
|
||||||
CreateImageRequestBody.builder()
|
CreateImageRequestBody.builder()
|
||||||
.image_type("message")
|
.image_type("message")
|
||||||
.image(open(file_path, "rb"))
|
.image(image_file)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
.build()
|
.build()
|
||||||
@@ -51,7 +63,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||||
image_key = response.data.image_key
|
image_key = response.data.image_key
|
||||||
print(image_key)
|
logger.debug(image_key)
|
||||||
ret.append(_stage)
|
ret.append(_stage)
|
||||||
ret.append([{"tag": "img", "image_key": image_key}])
|
ret.append([{"tag": "img", "image_key": image_key}])
|
||||||
_stage.clear()
|
_stage.clear()
|
||||||
@@ -91,3 +103,16 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import botpy
|
|||||||
import botpy.message
|
import botpy.message
|
||||||
import botpy.types
|
import botpy.types
|
||||||
import botpy.types.message
|
import botpy.types.message
|
||||||
|
import asyncio
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
@@ -9,6 +10,8 @@ from astrbot.api.message_components import Plain, Image
|
|||||||
from botpy import Client
|
from botpy import Client
|
||||||
from botpy.http import Route
|
from botpy.http import Route
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
|
from botpy.types import message
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||||
@@ -30,8 +33,45 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
self.send_buffer.chain.extend(message.chain)
|
self.send_buffer.chain.extend(message.chain)
|
||||||
|
|
||||||
async def _post_send(self):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
"""QQ 官方 API 仅支持回复一次"""
|
"""流式输出仅支持消息列表私聊"""
|
||||||
|
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||||
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
|
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||||
|
try:
|
||||||
|
async for chain in generator:
|
||||||
|
source = self.message_obj.raw_message
|
||||||
|
if not self.send_buffer:
|
||||||
|
self.send_buffer = chain
|
||||||
|
else:
|
||||||
|
self.send_buffer.chain.extend(chain.chain)
|
||||||
|
|
||||||
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
|
# 真流式传输
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
|
if time_since_last_edit >= throttle_interval:
|
||||||
|
ret = await self._post_send(stream=stream_payload)
|
||||||
|
stream_payload["index"] += 1
|
||||||
|
stream_payload["id"] = ret["id"]
|
||||||
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
|
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||||
|
stream_payload["state"] = 10
|
||||||
|
ret = await self._post_send(stream=stream_payload)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||||
|
self.send_buffer = None
|
||||||
|
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
async def _post_send(self, stream: dict = None):
|
||||||
|
if not self.send_buffer:
|
||||||
|
return
|
||||||
|
|
||||||
source = self.message_obj.raw_message
|
source = self.message_obj.raw_message
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
source,
|
source,
|
||||||
@@ -57,6 +97,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
"msg_id": self.message_obj.message_id,
|
"msg_id": self.message_obj.message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
|
||||||
|
payload["msg_seq"] = random.randint(1, 10000)
|
||||||
|
|
||||||
match type(source):
|
match type(source):
|
||||||
case botpy.message.GroupMessage:
|
case botpy.message.GroupMessage:
|
||||||
if image_base64:
|
if image_base64:
|
||||||
@@ -65,7 +108,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
await self.bot.api.post_group_message(
|
ret = await self.bot.api.post_group_message(
|
||||||
group_openid=source.group_openid, **payload
|
group_openid=source.group_openid, **payload
|
||||||
)
|
)
|
||||||
case botpy.message.C2CMessage:
|
case botpy.message.C2CMessage:
|
||||||
@@ -75,22 +118,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
await self.bot.api.post_c2c_message(
|
if stream:
|
||||||
openid=source.author.user_openid, **payload
|
ret = await self.post_c2c_message(
|
||||||
)
|
openid=source.author.user_openid,
|
||||||
|
**payload,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ret = await self.post_c2c_message(
|
||||||
|
openid=source.author.user_openid, **payload
|
||||||
|
)
|
||||||
|
logger.debug(f"Message sent to C2C: {ret}")
|
||||||
case botpy.message.Message:
|
case botpy.message.Message:
|
||||||
if image_path:
|
if image_path:
|
||||||
payload["file_image"] = image_path
|
payload["file_image"] = image_path
|
||||||
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
|
ret = await self.bot.api.post_message(
|
||||||
|
channel_id=source.channel_id, **payload
|
||||||
|
)
|
||||||
case botpy.message.DirectMessage:
|
case botpy.message.DirectMessage:
|
||||||
if image_path:
|
if image_path:
|
||||||
payload["file_image"] = image_path
|
payload["file_image"] = image_path
|
||||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||||
|
|
||||||
await super().send(self.send_buffer)
|
await super().send(self.send_buffer)
|
||||||
|
|
||||||
self.send_buffer = None
|
self.send_buffer = None
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
async def upload_group_and_c2c_image(
|
async def upload_group_and_c2c_image(
|
||||||
self, image_base64: str, file_type: int, **kwargs
|
self, image_base64: str, file_type: int, **kwargs
|
||||||
) -> botpy.types.message.Media:
|
) -> botpy.types.message.Media:
|
||||||
@@ -112,6 +167,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
|
async def post_c2c_message(
|
||||||
|
self,
|
||||||
|
openid: str,
|
||||||
|
msg_type: int = 0,
|
||||||
|
content: str = None,
|
||||||
|
embed: message.Embed = None,
|
||||||
|
ark: message.Ark = None,
|
||||||
|
message_reference: message.Reference = None,
|
||||||
|
media: message.Media = None,
|
||||||
|
msg_id: str = None,
|
||||||
|
msg_seq: str = 1,
|
||||||
|
event_id: str = None,
|
||||||
|
markdown: message.MarkdownPayload = None,
|
||||||
|
keyboard: message.Keyboard = None,
|
||||||
|
stream: dict = None,
|
||||||
|
) -> message.Message:
|
||||||
|
payload = locals()
|
||||||
|
payload.pop("self", None)
|
||||||
|
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||||
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _parse_to_qqofficial(message: MessageChain):
|
async def _parse_to_qqofficial(message: MessageChain):
|
||||||
plain_text = ""
|
plain_text = ""
|
||||||
@@ -122,16 +198,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
plain_text += i.text
|
plain_text += i.text
|
||||||
elif isinstance(i, Image) and not image_base64:
|
elif isinstance(i, Image) and not image_base64:
|
||||||
if i.file and i.file.startswith("file:///"):
|
if i.file and i.file.startswith("file:///"):
|
||||||
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
|
image_base64 = file_to_base64(i.file[8:])
|
||||||
image_file_path = i.file[8:]
|
image_file_path = i.file[8:]
|
||||||
elif i.file and i.file.startswith("http"):
|
elif i.file and i.file.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(i.file)
|
image_file_path = await download_image_by_url(i.file)
|
||||||
image_base64 = file_to_base64(image_file_path).replace(
|
image_base64 = file_to_base64(image_file_path)
|
||||||
"base64://", ""
|
elif i.file and i.file.startswith("base64://"):
|
||||||
)
|
image_base64 = i.file
|
||||||
else:
|
else:
|
||||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
image_base64 = file_to_base64(i.file)
|
||||||
image_file_path = i.file
|
image_base64 = image_base64.removeprefix("base64://")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"qq_official 忽略 {i.type}")
|
logger.debug(f"qq_official 忽略 {i.type}")
|
||||||
return plain_text, image_base64, image_file_path
|
return plain_text, image_base64, image_file_path
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from astrbot.api.platform import (
|
|||||||
MessageType,
|
MessageType,
|
||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
)
|
)
|
||||||
|
from astrbot import logger
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
from astrbot.api.message_components import Image, Plain, At
|
from astrbot.api.message_components import Image, Plain, At
|
||||||
@@ -125,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"qq_official",
|
name="qq_official",
|
||||||
"QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -204,3 +206,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> botClient:
|
def get_client(self) -> botClient:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
await self.client.close()
|
||||||
|
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .qo_webhook_event import QQOfficialWebhookMessageEvent
|
|||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
from .qo_webhook_server import QQOfficialWebhook
|
from .qo_webhook_server import QQOfficialWebhook
|
||||||
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
# remove logger handler
|
# remove logger handler
|
||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
@@ -98,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"qq_official_webhook",
|
name="qq_official_webhook",
|
||||||
"QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
@@ -111,3 +113,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> botClient:
|
def get_client(self) -> botClient:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.webhook_helper.shutdown_event.set()
|
||||||
|
await self.client.close()
|
||||||
|
try:
|
||||||
|
await self.webhook_helper.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class QQOfficialWebhook:
|
|||||||
self.appid = config["appid"]
|
self.appid = config["appid"]
|
||||||
self.secret = config["secret"]
|
self.secret = config["secret"]
|
||||||
self.port = config.get("port", 6196)
|
self.port = config.get("port", 6196)
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
|
|
||||||
if isinstance(self.port, str):
|
if isinstance(self.port, str):
|
||||||
self.port = int(self.port)
|
self.port = int(self.port)
|
||||||
@@ -29,6 +30,7 @@ class QQOfficialWebhook:
|
|||||||
)
|
)
|
||||||
self.client = botpy_client
|
self.client = botpy_client
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
logger.info("正在登录到 QQ 官方机器人...")
|
logger.info("正在登录到 QQ 官方机器人...")
|
||||||
@@ -95,13 +97,14 @@ class QQOfficialWebhook:
|
|||||||
return {"opcode": 12}
|
return {"opcode": 12}
|
||||||
|
|
||||||
async def start_polling(self):
|
async def start_polling(self):
|
||||||
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。"
|
||||||
|
)
|
||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host="0.0.0.0",
|
host=self.callback_server_host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder,
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
while not self.event_queue.closed: # noqa: ASYNC110
|
await self.shutdown_event.wait()
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("qq_official_webhook 适配器已关闭。")
|
|
||||||
|
|||||||
@@ -1,33 +1,32 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from telegram import BotCommand, Update
|
||||||
|
from telegram.constants import ChatType
|
||||||
|
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
|
||||||
|
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||||
|
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
Platform,
|
|
||||||
AstrBotMessage,
|
AstrBotMessage,
|
||||||
MessageMember,
|
MessageMember,
|
||||||
PlatformMetadata,
|
|
||||||
MessageType,
|
MessageType,
|
||||||
)
|
Platform,
|
||||||
from astrbot.api.event import MessageChain
|
PlatformMetadata,
|
||||||
from astrbot.api.message_components import (
|
register_platform_adapter,
|
||||||
Plain,
|
|
||||||
Image,
|
|
||||||
Record,
|
|
||||||
File as AstrBotFile,
|
|
||||||
Video,
|
|
||||||
At,
|
|
||||||
)
|
)
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from astrbot.api.platform import register_platform_adapter
|
from astrbot.core.star.filter.command import CommandFilter
|
||||||
|
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
|
|
||||||
from telegram import Update
|
|
||||||
from telegram.ext import ApplicationBuilder, ContextTypes, filters
|
|
||||||
from telegram.constants import ChatType
|
|
||||||
from telegram.ext import MessageHandler as TelegramMessageHandler
|
|
||||||
from .tg_event import TelegramPlatformEvent
|
from .tg_event import TelegramPlatformEvent
|
||||||
from astrbot.api import logger
|
|
||||||
from telegram.ext import ExtBot
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -50,18 +49,31 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.telegram.org/bot"
|
base_url = "https://api.telegram.org/bot"
|
||||||
|
|
||||||
|
file_base_url = self.config.get(
|
||||||
|
"telegram_file_base_url", "https://api.telegram.org/file/bot"
|
||||||
|
)
|
||||||
|
if not file_base_url:
|
||||||
|
file_base_url = "https://api.telegram.org/file/bot"
|
||||||
|
|
||||||
|
self.base_url = base_url
|
||||||
|
|
||||||
self.application = (
|
self.application = (
|
||||||
ApplicationBuilder()
|
ApplicationBuilder()
|
||||||
.token(self.config["telegram_token"])
|
.token(self.config["telegram_token"])
|
||||||
.base_url(base_url)
|
.base_url(base_url)
|
||||||
|
.base_file_url(file_base_url)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
message_handler = TelegramMessageHandler(
|
message_handler = TelegramMessageHandler(
|
||||||
filters=filters.ALL, # receive all messages
|
filters=filters.ALL, # receive all messages
|
||||||
callback=self.convert_message,
|
callback=self.message_handler,
|
||||||
)
|
)
|
||||||
self.application.add_handler(message_handler)
|
self.application.add_handler(message_handler)
|
||||||
self.client = self.application.bot
|
self.client = self.application.bot
|
||||||
|
logger.debug(f"Telegram base url: {self.client.base_url}")
|
||||||
|
|
||||||
|
self.scheduler = AsyncIOScheduler()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
@@ -76,87 +88,240 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
@override
|
@override
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"telegram",
|
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||||
"telegram 适配器",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await self.application.initialize()
|
await self.application.initialize()
|
||||||
await self.application.start()
|
await self.application.start()
|
||||||
|
await self.register_commands()
|
||||||
|
|
||||||
|
# TODO 使用更优雅的方式重新注册命令
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self.register_commands,
|
||||||
|
"interval",
|
||||||
|
minutes=5,
|
||||||
|
id="telegram_command_register",
|
||||||
|
misfire_grace_time=60,
|
||||||
|
)
|
||||||
|
self.scheduler.start()
|
||||||
|
|
||||||
queue = self.application.updater.start_polling()
|
queue = self.application.updater.start_polling()
|
||||||
logger.info("Telegram Platform Adapter is running.")
|
logger.info("Telegram Platform Adapter is running.")
|
||||||
await queue
|
await queue
|
||||||
|
|
||||||
|
async def register_commands(self):
|
||||||
|
"""收集所有注册的指令并注册到 Telegram"""
|
||||||
|
try:
|
||||||
|
await self.client.delete_my_commands()
|
||||||
|
commands = self.collect_commands()
|
||||||
|
|
||||||
|
if commands:
|
||||||
|
await self.client.set_my_commands(commands)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
|
||||||
|
|
||||||
|
def collect_commands(self) -> list[BotCommand]:
|
||||||
|
"""从注册的处理器中收集所有指令"""
|
||||||
|
command_dict = {}
|
||||||
|
skip_commands = {"start"}
|
||||||
|
|
||||||
|
for handler_md in star_handlers_registry._handlers:
|
||||||
|
handler_metadata = handler_md[1]
|
||||||
|
if not star_map[handler_metadata.handler_module_path].activated:
|
||||||
|
continue
|
||||||
|
for event_filter in handler_metadata.event_filters:
|
||||||
|
cmd_info = self._extract_command_info(
|
||||||
|
event_filter, handler_metadata, skip_commands
|
||||||
|
)
|
||||||
|
if cmd_info:
|
||||||
|
cmd_name, description = cmd_info
|
||||||
|
command_dict.setdefault(cmd_name, description)
|
||||||
|
|
||||||
|
commands_a = sorted(command_dict.keys())
|
||||||
|
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_command_info(
|
||||||
|
event_filter, handler_metadata, skip_commands: set
|
||||||
|
) -> tuple[str, str] | None:
|
||||||
|
"""从事件过滤器中提取指令信息"""
|
||||||
|
cmd_name = None
|
||||||
|
is_group = False
|
||||||
|
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
|
||||||
|
if (
|
||||||
|
event_filter.parent_command_names
|
||||||
|
and event_filter.parent_command_names != [""]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
cmd_name = event_filter.command_name
|
||||||
|
elif isinstance(event_filter, CommandGroupFilter):
|
||||||
|
if event_filter.parent_group:
|
||||||
|
return None
|
||||||
|
cmd_name = event_filter.group_name
|
||||||
|
is_group = True
|
||||||
|
|
||||||
|
if not cmd_name or cmd_name in skip_commands:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
||||||
|
logger.debug(f"跳过无法注册的命令: {cmd_name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build description.
|
||||||
|
description = handler_metadata.desc or (
|
||||||
|
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
|
||||||
|
)
|
||||||
|
if len(description) > 30:
|
||||||
|
description = description[:30] + "..."
|
||||||
|
return cmd_name, description
|
||||||
|
|
||||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
await context.bot.send_message(
|
await context.bot.send_message(
|
||||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
logger.debug(f"Telegram message: {update.message}")
|
||||||
|
abm = await self.convert_message(update, context)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
async def convert_message(
|
async def convert_message(
|
||||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE
|
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||||
) -> AstrBotMessage:
|
) -> AstrBotMessage:
|
||||||
|
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||||
|
|
||||||
|
@param update: Telegram 的 Update 对象。
|
||||||
|
@param context: Telegram 的 Context 对象。
|
||||||
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
|
"""
|
||||||
message = AstrBotMessage()
|
message = AstrBotMessage()
|
||||||
|
message.session_id = str(update.message.chat.id)
|
||||||
# 获得是群聊还是私聊
|
# 获得是群聊还是私聊
|
||||||
if update.effective_chat.type == ChatType.PRIVATE:
|
if update.message.chat.type == ChatType.PRIVATE:
|
||||||
message.type = MessageType.FRIEND_MESSAGE
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
else:
|
else:
|
||||||
message.type = MessageType.GROUP_MESSAGE
|
message.type = MessageType.GROUP_MESSAGE
|
||||||
message.group_id = update.effective_chat.id
|
message.group_id = str(update.message.chat.id)
|
||||||
|
if update.message.message_thread_id:
|
||||||
|
# Topic Group
|
||||||
|
message.group_id += "#" + str(update.message.message_thread_id)
|
||||||
|
message.session_id = message.group_id
|
||||||
|
|
||||||
message.message_id = str(update.message.message_id)
|
message.message_id = str(update.message.message_id)
|
||||||
message.session_id = str(update.effective_chat.id)
|
|
||||||
message.sender = MessageMember(
|
message.sender = MessageMember(
|
||||||
str(update.effective_user.id), update.effective_user.username
|
str(update.message.from_user.id), update.message.from_user.username
|
||||||
)
|
)
|
||||||
message.self_id = str(context.bot.username)
|
message.self_id = str(context.bot.username)
|
||||||
message.raw_message = update
|
message.raw_message = update
|
||||||
message.message_str = ""
|
message.message_str = ""
|
||||||
message.message = []
|
message.message = []
|
||||||
|
|
||||||
logger.debug(f"Telegram message: {update.message}")
|
if update.message.reply_to_message and not (
|
||||||
|
update.message.is_topic_message
|
||||||
|
and update.message.message_thread_id
|
||||||
|
== update.message.reply_to_message.message_id
|
||||||
|
):
|
||||||
|
# 获取回复消息
|
||||||
|
reply_update = Update(
|
||||||
|
update_id=1,
|
||||||
|
message=update.message.reply_to_message,
|
||||||
|
)
|
||||||
|
reply_abm = await self.convert_message(reply_update, context, False)
|
||||||
|
|
||||||
|
message.message.append(
|
||||||
|
Comp.Reply(
|
||||||
|
id=reply_abm.message_id,
|
||||||
|
chain=reply_abm.message,
|
||||||
|
sender_id=reply_abm.sender.user_id,
|
||||||
|
sender_nickname=reply_abm.sender.nickname,
|
||||||
|
time=reply_abm.timestamp,
|
||||||
|
message_str=reply_abm.message_str,
|
||||||
|
text=reply_abm.message_str,
|
||||||
|
qq=reply_abm.sender.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if update.message.text:
|
if update.message.text:
|
||||||
|
# 处理文本消息
|
||||||
plain_text = update.message.text
|
plain_text = update.message.text
|
||||||
|
|
||||||
|
# 群聊场景命令特殊处理
|
||||||
|
if plain_text.startswith("/"):
|
||||||
|
command_parts = plain_text.split(" ", 1)
|
||||||
|
if "@" in command_parts[0]:
|
||||||
|
command, bot_name = command_parts[0].split("@")
|
||||||
|
if bot_name == self.client.username:
|
||||||
|
plain_text = command + (
|
||||||
|
f" {command_parts[1]}" if len(command_parts) > 1 else ""
|
||||||
|
)
|
||||||
|
|
||||||
if update.message.entities:
|
if update.message.entities:
|
||||||
for entity in update.message.entities:
|
for entity in update.message.entities:
|
||||||
if entity.type == "mention":
|
if entity.type == "mention":
|
||||||
name = plain_text[entity.offset+1 : entity.offset + entity.length]
|
name = plain_text[
|
||||||
message.message.append(At(qq=name, name=name))
|
entity.offset + 1 : entity.offset + entity.length
|
||||||
|
]
|
||||||
|
message.message.append(Comp.At(qq=name, name=name))
|
||||||
plain_text = (
|
plain_text = (
|
||||||
plain_text[: entity.offset]
|
plain_text[: entity.offset]
|
||||||
+ plain_text[entity.offset + entity.length :]
|
+ plain_text[entity.offset + entity.length :]
|
||||||
)
|
)
|
||||||
|
|
||||||
message.message.append(Plain(plain_text))
|
if plain_text:
|
||||||
|
message.message.append(Comp.Plain(plain_text))
|
||||||
message.message_str = plain_text
|
message.message_str = plain_text
|
||||||
|
|
||||||
|
if message.message_str.strip() == "/start":
|
||||||
|
await self.start(update, context)
|
||||||
|
return
|
||||||
|
|
||||||
elif update.message.voice:
|
elif update.message.voice:
|
||||||
file = await update.message.voice.get_file()
|
file = await update.message.voice.get_file()
|
||||||
message.message = [
|
message.message = [
|
||||||
Record(file=file.file_path, url=file.file_path),
|
Comp.Record(file=file.file_path, url=file.file_path),
|
||||||
]
|
]
|
||||||
|
|
||||||
elif update.message.photo:
|
elif update.message.photo:
|
||||||
photo = update.message.photo[-1] # get the largest photo
|
photo = update.message.photo[-1] # get the largest photo
|
||||||
file = await photo.get_file()
|
file = await photo.get_file()
|
||||||
message.message.append(Image(file=file.file_path, url=file.file_path))
|
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||||
|
if update.message.caption:
|
||||||
|
message.message_str = update.message.caption
|
||||||
|
message.message.append(Comp.Plain(message.message_str))
|
||||||
|
if update.message.caption_entities:
|
||||||
|
for entity in update.message.caption_entities:
|
||||||
|
if entity.type == "mention":
|
||||||
|
name = message.message_str[
|
||||||
|
entity.offset + 1 : entity.offset + entity.length
|
||||||
|
]
|
||||||
|
message.message.append(Comp.At(qq=name, name=name))
|
||||||
|
|
||||||
|
elif update.message.sticker:
|
||||||
|
# 将sticker当作图片处理
|
||||||
|
file = await update.message.sticker.get_file()
|
||||||
|
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||||
|
if update.message.sticker.emoji:
|
||||||
|
sticker_text = f"Sticker: {update.message.sticker.emoji}"
|
||||||
|
message.message_str = sticker_text
|
||||||
|
message.message.append(Comp.Plain(sticker_text))
|
||||||
|
|
||||||
elif update.message.document:
|
elif update.message.document:
|
||||||
file = await update.message.document.get_file()
|
file = await update.message.document.get_file()
|
||||||
message.message = [
|
message.message = [
|
||||||
AstrBotFile(
|
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||||
file=file.file_path, name=update.message.document.file_name
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
elif update.message.video:
|
elif update.message.video:
|
||||||
file = await update.message.video.get_file()
|
file = await update.message.video.get_file()
|
||||||
message.message = [
|
message.message = [
|
||||||
Video(file=file.file_path, path=file.file_path),
|
Comp.Video(file=file.file_path, path=file.file_path),
|
||||||
]
|
]
|
||||||
|
|
||||||
await self.handle_msg(message)
|
return message
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
message_event = TelegramPlatformEvent(
|
message_event = TelegramPlatformEvent(
|
||||||
@@ -170,3 +335,19 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> ExtBot:
|
def get_client(self) -> ExtBot:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
try:
|
||||||
|
if self.scheduler.running:
|
||||||
|
self.scheduler.shutdown()
|
||||||
|
|
||||||
|
await self.application.stop()
|
||||||
|
await self.client.delete_my_commands()
|
||||||
|
|
||||||
|
# 保险起见先判断是否存在updater对象
|
||||||
|
if self.application.updater is not None:
|
||||||
|
await self.application.updater.stop()
|
||||||
|
|
||||||
|
logger.info("Telegram 适配器已被优雅地关闭")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Telegram 适配器关闭时出错: {e}")
|
||||||
|
|||||||
@@ -1,7 +1,18 @@
|
|||||||
|
import asyncio
|
||||||
|
import telegramify_markdown
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||||
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
Reply,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Record,
|
||||||
|
)
|
||||||
from telegram.ext import ExtBot
|
from telegram.ext import ExtBot
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
class TelegramPlatformEvent(AstrMessageEvent):
|
class TelegramPlatformEvent(AstrMessageEvent):
|
||||||
@@ -31,36 +42,47 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
at_user_id = i.name
|
at_user_id = i.name
|
||||||
|
|
||||||
at_flag = False
|
at_flag = False
|
||||||
|
message_thread_id = None
|
||||||
|
if "#" in user_name:
|
||||||
|
# it's a supergroup chat with message_thread_id
|
||||||
|
user_name, message_thread_id = user_name.split("#")
|
||||||
for i in message.chain:
|
for i in message.chain:
|
||||||
payload = {
|
payload = {
|
||||||
"chat_id": user_name,
|
"chat_id": user_name,
|
||||||
}
|
}
|
||||||
if has_reply:
|
if has_reply:
|
||||||
payload["reply_to_message_id"] = reply_message_id
|
payload["reply_to_message_id"] = reply_message_id
|
||||||
|
if message_thread_id:
|
||||||
|
payload["message_thread_id"] = message_thread_id
|
||||||
|
|
||||||
if isinstance(i, Plain):
|
if isinstance(i, Plain):
|
||||||
if at_user_id and not at_flag:
|
if at_user_id and not at_flag:
|
||||||
i.text = f"@{at_user_id} " + i.text
|
i.text = f"@{at_user_id} " + i.text
|
||||||
at_flag = True
|
at_flag = True
|
||||||
await client.send_message(text=i.text, **payload)
|
text = i.text
|
||||||
|
try:
|
||||||
|
text = telegramify_markdown.markdownify(
|
||||||
|
i.text, max_line_length=None, normalize_whitespace=False
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||||
elif isinstance(i, Image):
|
elif isinstance(i, Image):
|
||||||
if i.path:
|
image_path = await i.convert_to_file_path()
|
||||||
image_path = i.path
|
await client.send_photo(photo=image_path, **payload)
|
||||||
else:
|
|
||||||
image_path = i.file
|
|
||||||
|
|
||||||
if image_path.startswith("base64://"):
|
|
||||||
import base64
|
|
||||||
|
|
||||||
base64_data = image_path[9:]
|
|
||||||
image_bytes = base64.b64decode(base64_data)
|
|
||||||
await client.send_photo(photo=image_bytes, **payload)
|
|
||||||
else:
|
|
||||||
await client.send_photo(photo=image_path, **payload)
|
|
||||||
elif isinstance(i, File):
|
elif isinstance(i, File):
|
||||||
|
if i.file.startswith("https://"):
|
||||||
|
path = "data/temp/" + i.name
|
||||||
|
await download_file(i.file, path)
|
||||||
|
i.file = path
|
||||||
|
|
||||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
await client.send_voice(voice=i.file, **payload)
|
path = await i.convert_to_file_path()
|
||||||
|
await client.send_voice(voice=path, **payload)
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
@@ -68,3 +90,107 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
message_thread_id = None
|
||||||
|
|
||||||
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
|
user_name = self.message_obj.group_id
|
||||||
|
else:
|
||||||
|
user_name = self.get_sender_id()
|
||||||
|
|
||||||
|
if "#" in user_name:
|
||||||
|
# it's a supergroup chat with message_thread_id
|
||||||
|
user_name, message_thread_id = user_name.split("#")
|
||||||
|
payload = {
|
||||||
|
"chat_id": user_name,
|
||||||
|
}
|
||||||
|
if message_thread_id:
|
||||||
|
payload["reply_to_message_id"] = message_thread_id
|
||||||
|
|
||||||
|
delta = ""
|
||||||
|
current_content = ""
|
||||||
|
message_id = None
|
||||||
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
|
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
# 处理消息链中的每个组件
|
||||||
|
for i in chain.chain:
|
||||||
|
if isinstance(i, Plain):
|
||||||
|
delta += i.text
|
||||||
|
elif isinstance(i, Image):
|
||||||
|
image_path = await i.convert_to_file_path()
|
||||||
|
await self.client.send_photo(photo=image_path, **payload)
|
||||||
|
continue
|
||||||
|
elif isinstance(i, File):
|
||||||
|
if i.file.startswith("https://"):
|
||||||
|
path = "data/temp/" + i.name
|
||||||
|
await download_file(i.file, path)
|
||||||
|
i.file = path
|
||||||
|
|
||||||
|
await self.client.send_document(
|
||||||
|
document=i.file, filename=i.name, **payload
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif isinstance(i, Record):
|
||||||
|
path = await i.convert_to_file_path()
|
||||||
|
await self.client.send_voice(voice=path, **payload)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Plain
|
||||||
|
if not message_id:
|
||||||
|
try:
|
||||||
|
msg = await self.client.send_message(text=delta, **payload)
|
||||||
|
current_content = delta
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
|
message_id = msg.message_id
|
||||||
|
last_edit_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
) # 记录初始消息发送时间
|
||||||
|
else:
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
|
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
|
||||||
|
if time_since_last_edit >= throttle_interval:
|
||||||
|
# 编辑消息
|
||||||
|
try:
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
current_content = delta
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
|
last_edit_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
) # 更新上次编辑的时间
|
||||||
|
|
||||||
|
try:
|
||||||
|
if delta and current_content != delta:
|
||||||
|
try:
|
||||||
|
markdown_text = telegramify_markdown.markdownify(
|
||||||
|
delta, max_line_length=None, normalize_whitespace=False
|
||||||
|
)
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=markdown_text,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
parse_mode="MarkdownV2",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
|
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from astrbot.core.platform import (
|
|||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
|
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
from astrbot.core import web_chat_queue
|
||||||
from .webchat_event import WebChatMessageEvent
|
from .webchat_event import WebChatMessageEvent
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
@@ -43,21 +43,13 @@ class WebChatAdapter(Platform):
|
|||||||
self.imgs_dir = "data/webchat/imgs"
|
self.imgs_dir = "data/webchat/imgs"
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
"webchat",
|
name="webchat", description="webchat", id=self.config.get("id")
|
||||||
"webchat",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
):
|
):
|
||||||
# abm.session_id = f"webchat!{username}!{cid}"
|
await WebChatMessageEvent._send(message_chain, session.session_id)
|
||||||
plain = ""
|
|
||||||
cid = session.session_id.split("!")[-1]
|
|
||||||
for comp in message_chain.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
plain += comp.text
|
|
||||||
web_chat_back_queue.put_nowait((plain, cid))
|
|
||||||
|
|
||||||
await super().send_by_session(session, message_chain)
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||||
@@ -126,3 +118,7 @@ class WebChatAdapter(Platform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
# Do nothing
|
||||||
|
pass
|
||||||
|
|||||||
@@ -3,31 +3,43 @@ import uuid
|
|||||||
import base64
|
import base64
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core import web_chat_back_queue
|
from astrbot.core import web_chat_back_queue
|
||||||
|
|
||||||
|
imgs_dir = "data/webchat/imgs"
|
||||||
|
|
||||||
|
|
||||||
class WebChatMessageEvent(AstrMessageEvent):
|
class WebChatMessageEvent(AstrMessageEvent):
|
||||||
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.imgs_dir = "data/webchat/imgs"
|
os.makedirs(imgs_dir, exist_ok=True)
|
||||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
@staticmethod
|
||||||
|
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||||
if not message:
|
if not message:
|
||||||
web_chat_back_queue.put_nowait(None)
|
await web_chat_back_queue.put(
|
||||||
return
|
{"type": "end", "data": "", "streaming": False}
|
||||||
|
)
|
||||||
cid = self.session_id.split("!")[-1]
|
return ""
|
||||||
|
|
||||||
|
cid = session_id.split("!")[-1]
|
||||||
|
data = ""
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
data = comp.text
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "plain",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
# save image to local
|
# save image to local
|
||||||
filename = str(uuid.uuid4()) + ".jpg"
|
filename = str(uuid.uuid4()) + ".jpg"
|
||||||
path = os.path.join(self.imgs_dir, filename)
|
path = os.path.join(imgs_dir, filename)
|
||||||
if comp.file and comp.file.startswith("file:///"):
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
ph = comp.file[8:]
|
ph = comp.file[8:]
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
@@ -44,8 +56,69 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
with open(comp.file, "rb") as f2:
|
with open(comp.file, "rb") as f2:
|
||||||
f.write(f2.read())
|
f.write(f2.read())
|
||||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
data = f"[IMAGE]{filename}"
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Record):
|
||||||
|
# save record to local
|
||||||
|
filename = str(uuid.uuid4()) + ".wav"
|
||||||
|
path = os.path.join(imgs_dir, filename)
|
||||||
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
|
ph = comp.file[8:]
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
with open(ph, "rb") as f2:
|
||||||
|
f.write(f2.read())
|
||||||
|
elif comp.file and comp.file.startswith("http"):
|
||||||
|
await download_image_by_url(comp.file, path=path)
|
||||||
|
else:
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
with open(comp.file, "rb") as f2:
|
||||||
|
f.write(f2.read())
|
||||||
|
data = f"[RECORD]{filename}"
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "record",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"webchat 忽略: {comp.type}")
|
logger.debug(f"webchat 忽略: {comp.type}")
|
||||||
web_chat_back_queue.put_nowait(None)
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": "",
|
||||||
|
"streaming": False,
|
||||||
|
"cid": self.session_id.split("!")[-1],
|
||||||
|
}
|
||||||
|
)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
final_data = ""
|
||||||
|
async for chain in generator:
|
||||||
|
final_data += await WebChatMessageEvent._send(
|
||||||
|
chain, session_id=self.session_id, streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": final_data,
|
||||||
|
"streaming": True,
|
||||||
|
"cid": self.session_id.split("!")[-1],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class WecomServer:
|
|||||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.port = int(config.get("port"))
|
self.port = int(config.get("port"))
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/callback/command", view_func=self.verify, methods=["GET"]
|
"/callback/command", view_func=self.verify, methods=["GET"]
|
||||||
)
|
)
|
||||||
@@ -49,6 +50,7 @@ class WecomServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.callback = None
|
self.callback = None
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def verify(self):
|
async def verify(self):
|
||||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||||
@@ -86,17 +88,17 @@ class WecomServer:
|
|||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
async def start_polling(self):
|
async def start_polling(self):
|
||||||
logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。")
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。"
|
||||||
|
)
|
||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host="0.0.0.0",
|
host=self.callback_server_host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder,
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
while not self.event_queue.closed: # noqa: ASYNC110
|
await self.shutdown_event.wait()
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("企业微信 适配器已关闭。")
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("wecom", "wecom 适配器")
|
@register_platform_adapter("wecom", "wecom 适配器")
|
||||||
@@ -232,3 +234,11 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> WeChatClient:
|
def get_client(self) -> WeChatClient:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.server.shutdown_event.set()
|
||||||
|
try:
|
||||||
|
await self.server.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("企业微信 适配器已被优雅地关闭")
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import asyncio
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
from astrbot.api.message_components import Plain, Image, Record
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
from wechatpy.enterprise import WeChatClient
|
from wechatpy.enterprise import WeChatClient
|
||||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
|
||||||
|
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
|
|
||||||
@@ -34,23 +34,56 @@ class WecomPlatformEvent(AstrMessageEvent):
|
|||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def split_plain(self, plain: str) -> list[str]:
|
||||||
|
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plain (str): 要分割的长文本
|
||||||
|
Returns:
|
||||||
|
list[str]: 分割后的文本列表
|
||||||
|
"""
|
||||||
|
if len(plain) <= 2048:
|
||||||
|
return [plain]
|
||||||
|
else:
|
||||||
|
result = []
|
||||||
|
start = 0
|
||||||
|
while start < len(plain):
|
||||||
|
# 剩下的字符串长度<2048时结束
|
||||||
|
if start + 2048 >= len(plain):
|
||||||
|
result.append(plain[start:])
|
||||||
|
break
|
||||||
|
|
||||||
|
# 向前搜索分割标点符号
|
||||||
|
end = min(start + 2048, len(plain))
|
||||||
|
cut_position = end
|
||||||
|
for i in range(end, start, -1):
|
||||||
|
if i < len(plain) and plain[i-1] in ["。", "!", "?", ".", "!", "?", "\n", ";", ";"]:
|
||||||
|
cut_position = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# 没找到合适的位置分割, 直接切分
|
||||||
|
if cut_position == end and end < len(plain):
|
||||||
|
cut_position = end
|
||||||
|
|
||||||
|
result.append(plain[start:cut_position])
|
||||||
|
start = cut_position
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
message_obj = self.message_obj
|
message_obj = self.message_obj
|
||||||
|
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
self.client.message.send_text(
|
# Split long text messages if needed
|
||||||
message_obj.self_id, message_obj.session_id, comp.text
|
plain_chunks = await self.split_plain(comp.text)
|
||||||
)
|
for chunk in plain_chunks:
|
||||||
|
self.client.message.send_text(
|
||||||
|
message_obj.self_id, message_obj.session_id, chunk
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
img_url = comp.file
|
img_path = await comp.convert_to_file_path()
|
||||||
img_path = ""
|
|
||||||
if img_url.startswith("file:///"):
|
|
||||||
img_path = img_url[8:]
|
|
||||||
elif comp.file and comp.file.startswith("http"):
|
|
||||||
img_path = await download_image_by_url(comp.file)
|
|
||||||
else:
|
|
||||||
img_path = img_url
|
|
||||||
|
|
||||||
with open(img_path, "rb") as f:
|
with open(img_path, "rb") as f:
|
||||||
try:
|
try:
|
||||||
@@ -68,16 +101,7 @@ class WecomPlatformEvent(AstrMessageEvent):
|
|||||||
response["media_id"],
|
response["media_id"],
|
||||||
)
|
)
|
||||||
elif isinstance(comp, Record):
|
elif isinstance(comp, Record):
|
||||||
record_url = comp.file
|
record_path = await comp.convert_to_file_path()
|
||||||
record_path = ""
|
|
||||||
|
|
||||||
if record_url.startswith("file:///"):
|
|
||||||
record_path = record_url[8:]
|
|
||||||
elif record_url.startswith("http"):
|
|
||||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
|
||||||
else:
|
|
||||||
record_path = record_url
|
|
||||||
|
|
||||||
# 转成amr
|
# 转成amr
|
||||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||||
pydub.AudioSegment.from_wav(record_path).export(
|
pydub.AudioSegment.from_wav(record_path).export(
|
||||||
@@ -101,3 +125,16 @@ class WecomPlatformEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from .provider import Provider, Personality, STTProvider
|
from .provider import Provider, Personality, STTProvider
|
||||||
|
|
||||||
from .entites import ProviderMetaData
|
from .entities import ProviderMetaData
|
||||||
|
|
||||||
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
|
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
|
||||||
|
|||||||
@@ -1,67 +1,19 @@
|
|||||||
import enum
|
from astrbot.core.provider.entities import (
|
||||||
from dataclasses import dataclass, field
|
ProviderRequest,
|
||||||
from typing import List, Dict, Type
|
ProviderType,
|
||||||
from .func_tool_manager import FuncCall
|
ProviderMetaData,
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
ToolCallsResult,
|
||||||
from astrbot.core.db.po import Conversation
|
AssistantMessageSegment,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
LLMResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
class ProviderType(enum.Enum):
|
"ProviderRequest",
|
||||||
CHAT_COMPLETION = "chat_completion"
|
"ProviderType",
|
||||||
SPEECH_TO_TEXT = "speech_to_text"
|
"ProviderMetaData",
|
||||||
TEXT_TO_SPEECH = "text_to_speech"
|
"ToolCallsResult",
|
||||||
|
"AssistantMessageSegment",
|
||||||
|
"ToolCallMessageSegment",
|
||||||
@dataclass
|
"LLMResponse",
|
||||||
class ProviderMetaData:
|
]
|
||||||
type: str
|
|
||||||
"""提供商适配器名称,如 openai, ollama"""
|
|
||||||
desc: str = ""
|
|
||||||
"""提供商适配器描述."""
|
|
||||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
|
||||||
cls_type: Type = None
|
|
||||||
|
|
||||||
default_config_tmpl: dict = None
|
|
||||||
"""平台的默认配置模板"""
|
|
||||||
provider_display_name: str = None
|
|
||||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProviderRequest:
|
|
||||||
prompt: str
|
|
||||||
"""提示词"""
|
|
||||||
session_id: str = ""
|
|
||||||
"""会话 ID"""
|
|
||||||
image_urls: List[str] = None
|
|
||||||
"""图片 URL 列表"""
|
|
||||||
func_tool: FuncCall = None
|
|
||||||
"""工具"""
|
|
||||||
contexts: List = None
|
|
||||||
"""上下文。格式与 openai 的上下文格式一致:
|
|
||||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
|
||||||
"""
|
|
||||||
system_prompt: str = ""
|
|
||||||
"""系统提示词"""
|
|
||||||
conversation: Conversation = None
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMResponse:
|
|
||||||
role: str
|
|
||||||
"""角色, assistant, tool, err"""
|
|
||||||
completion_text: str = ""
|
|
||||||
"""LLM 返回的文本"""
|
|
||||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
|
||||||
"""工具调用参数"""
|
|
||||||
tools_call_name: List[str] = field(default_factory=list)
|
|
||||||
"""工具调用名称"""
|
|
||||||
|
|
||||||
raw_completion: ChatCompletion = None
|
|
||||||
_new_record: Dict[str, any] = None
|
|
||||||
|
|||||||
281
astrbot/core/provider/entities.py
Normal file
281
astrbot/core/provider/entities.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
import enum
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from astrbot import logger
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Type
|
||||||
|
from .func_tool_manager import FuncCall
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
)
|
||||||
|
from astrbot.core.db.po import Conversation
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(enum.Enum):
|
||||||
|
CHAT_COMPLETION = "chat_completion"
|
||||||
|
SPEECH_TO_TEXT = "speech_to_text"
|
||||||
|
TEXT_TO_SPEECH = "text_to_speech"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderMetaData:
|
||||||
|
type: str
|
||||||
|
"""提供商适配器名称,如 openai, ollama"""
|
||||||
|
desc: str = ""
|
||||||
|
"""提供商适配器描述."""
|
||||||
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||||
|
cls_type: Type = None
|
||||||
|
|
||||||
|
default_config_tmpl: dict = None
|
||||||
|
"""平台的默认配置模板"""
|
||||||
|
provider_display_name: str = None
|
||||||
|
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallMessageSegment:
|
||||||
|
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
|
tool_call_id: str
|
||||||
|
content: str
|
||||||
|
role: str = "tool"
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"tool_call_id": self.tool_call_id,
|
||||||
|
"content": self.content,
|
||||||
|
"role": self.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistantMessageSegment:
|
||||||
|
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
|
content: str = None
|
||||||
|
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||||
|
role: str = "assistant"
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
ret = {
|
||||||
|
"role": self.role,
|
||||||
|
}
|
||||||
|
if self.content:
|
||||||
|
ret["content"] = self.content
|
||||||
|
elif self.tool_calls:
|
||||||
|
ret["tool_calls"] = self.tool_calls
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallsResult:
|
||||||
|
"""工具调用结果"""
|
||||||
|
|
||||||
|
tool_calls_info: AssistantMessageSegment
|
||||||
|
"""函数调用的信息"""
|
||||||
|
tool_calls_result: List[ToolCallMessageSegment]
|
||||||
|
"""函数调用的结果"""
|
||||||
|
|
||||||
|
def to_openai_messages(self) -> List[Dict]:
|
||||||
|
ret = [
|
||||||
|
self.tool_calls_info.to_dict(),
|
||||||
|
*[item.to_dict() for item in self.tool_calls_result],
|
||||||
|
]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderRequest:
|
||||||
|
prompt: str
|
||||||
|
"""提示词"""
|
||||||
|
session_id: str = ""
|
||||||
|
"""会话 ID"""
|
||||||
|
image_urls: List[str] = None
|
||||||
|
"""图片 URL 列表"""
|
||||||
|
func_tool: FuncCall = None
|
||||||
|
"""可用的函数工具"""
|
||||||
|
contexts: List = None
|
||||||
|
"""上下文。格式与 openai 的上下文格式一致:
|
||||||
|
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||||
|
"""
|
||||||
|
system_prompt: str = ""
|
||||||
|
"""系统提示词"""
|
||||||
|
conversation: Conversation = None
|
||||||
|
|
||||||
|
tool_calls_result: ToolCallsResult = None
|
||||||
|
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def _print_friendly_context(self):
|
||||||
|
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||||
|
if not self.contexts:
|
||||||
|
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
|
||||||
|
|
||||||
|
result_parts = []
|
||||||
|
|
||||||
|
for ctx in self.contexts:
|
||||||
|
role = ctx.get("role", "unknown")
|
||||||
|
content = ctx.get("content", "")
|
||||||
|
|
||||||
|
if isinstance(content, str):
|
||||||
|
result_parts.append(f"{role}: {content}")
|
||||||
|
elif isinstance(content, list):
|
||||||
|
msg_parts = []
|
||||||
|
image_count = 0
|
||||||
|
|
||||||
|
for item in content:
|
||||||
|
item_type = item.get("type", "")
|
||||||
|
|
||||||
|
if item_type == "text":
|
||||||
|
msg_parts.append(item.get("text", ""))
|
||||||
|
elif item_type == "image_url":
|
||||||
|
image_count += 1
|
||||||
|
|
||||||
|
if image_count > 0:
|
||||||
|
if msg_parts:
|
||||||
|
msg_parts.append(f"[+{image_count} images]")
|
||||||
|
else:
|
||||||
|
msg_parts.append(f"[{image_count} images]")
|
||||||
|
|
||||||
|
result_parts.append(f"{role}: {''.join(msg_parts)}")
|
||||||
|
|
||||||
|
return result_parts
|
||||||
|
|
||||||
|
async def assemble_context(self) -> Dict:
|
||||||
|
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||||
|
if self.image_urls:
|
||||||
|
user_content = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
|
||||||
|
}
|
||||||
|
for image_url in self.image_urls:
|
||||||
|
if image_url.startswith("http"):
|
||||||
|
image_path = await download_image_by_url(image_url)
|
||||||
|
image_data = await self._encode_image_bs64(image_path)
|
||||||
|
elif image_url.startswith("file:///"):
|
||||||
|
image_path = image_url.replace("file:///", "")
|
||||||
|
image_data = await self._encode_image_bs64(image_path)
|
||||||
|
else:
|
||||||
|
image_data = await self._encode_image_bs64(image_url)
|
||||||
|
if not image_data:
|
||||||
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
|
continue
|
||||||
|
user_content["content"].append(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_data}}
|
||||||
|
)
|
||||||
|
return user_content
|
||||||
|
else:
|
||||||
|
return {"role": "user", "content": self.prompt}
|
||||||
|
|
||||||
|
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||||
|
"""将图片转换为 base64"""
|
||||||
|
if image_url.startswith("base64://"):
|
||||||
|
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||||
|
with open(image_url, "rb") as f:
|
||||||
|
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||||
|
return "data:image/jpeg;base64," + image_bs64
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
role: str
|
||||||
|
"""角色, assistant, tool, err"""
|
||||||
|
result_chain: MessageChain = None
|
||||||
|
"""返回的消息链"""
|
||||||
|
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||||
|
"""工具调用参数"""
|
||||||
|
tools_call_name: List[str] = field(default_factory=list)
|
||||||
|
"""工具调用名称"""
|
||||||
|
tools_call_ids: List[str] = field(default_factory=list)
|
||||||
|
"""工具调用 ID"""
|
||||||
|
|
||||||
|
raw_completion: ChatCompletion = None
|
||||||
|
_new_record: Dict[str, any] = None
|
||||||
|
|
||||||
|
_completion_text: str = ""
|
||||||
|
|
||||||
|
is_chunk: bool = False
|
||||||
|
"""是否是流式输出的单个 Chunk"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
role: str,
|
||||||
|
completion_text: str = "",
|
||||||
|
result_chain: MessageChain = None,
|
||||||
|
tools_call_args: List[Dict[str, any]] = None,
|
||||||
|
tools_call_name: List[str] = None,
|
||||||
|
tools_call_ids: List[str] = None,
|
||||||
|
raw_completion: ChatCompletion = None,
|
||||||
|
_new_record: Dict[str, any] = None,
|
||||||
|
is_chunk: bool = False,
|
||||||
|
):
|
||||||
|
"""初始化 LLMResponse
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role (str): 角色, assistant, tool, err
|
||||||
|
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
|
||||||
|
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
|
||||||
|
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
|
||||||
|
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
|
||||||
|
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||||
|
"""
|
||||||
|
if tools_call_args is None:
|
||||||
|
tools_call_args = []
|
||||||
|
if tools_call_name is None:
|
||||||
|
tools_call_name = []
|
||||||
|
if tools_call_ids is None:
|
||||||
|
tools_call_ids = []
|
||||||
|
|
||||||
|
self.role = role
|
||||||
|
self.completion_text = completion_text
|
||||||
|
self.result_chain = result_chain
|
||||||
|
self.tools_call_args = tools_call_args
|
||||||
|
self.tools_call_name = tools_call_name
|
||||||
|
self.tools_call_ids = tools_call_ids
|
||||||
|
self.raw_completion = raw_completion
|
||||||
|
self._new_record = _new_record
|
||||||
|
self.is_chunk = is_chunk
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completion_text(self):
|
||||||
|
if self.result_chain:
|
||||||
|
return self.result_chain.get_plain_text()
|
||||||
|
return self._completion_text
|
||||||
|
|
||||||
|
@completion_text.setter
|
||||||
|
def completion_text(self, value):
|
||||||
|
if self.result_chain:
|
||||||
|
self.result_chain.chain = [
|
||||||
|
comp
|
||||||
|
for comp in self.result_chain.chain
|
||||||
|
if not isinstance(comp, Comp.Plain)
|
||||||
|
] # 清空 Plain 组件
|
||||||
|
self.result_chain.chain.insert(0, Comp.Plain(value))
|
||||||
|
else:
|
||||||
|
self._completion_text = value
|
||||||
|
|
||||||
|
def to_openai_tool_calls(self) -> List[Dict]:
|
||||||
|
"""将工具调用信息转换为 OpenAI 格式"""
|
||||||
|
ret = []
|
||||||
|
for idx, tool_call_arg in enumerate(self.tools_call_args):
|
||||||
|
ret.append(
|
||||||
|
{
|
||||||
|
"id": self.tools_call_ids[idx],
|
||||||
|
"function": {
|
||||||
|
"name": self.tools_call_name[idx],
|
||||||
|
"arguments": json.dumps(tool_call_arg),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ret
|
||||||
@@ -1,7 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Dict, List, Awaitable
|
import os
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import Dict, List, Awaitable, Literal, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.utils.log_pipe import LogPipe
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mcp
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||||
|
|
||||||
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||||
|
|
||||||
|
SUPPORTED_TYPES = [
|
||||||
|
"string",
|
||||||
|
"number",
|
||||||
|
"object",
|
||||||
|
"array",
|
||||||
|
"boolean",
|
||||||
|
] # json schema 支持的数据类型
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -13,25 +39,133 @@ class FuncTool:
|
|||||||
name: str
|
name: str
|
||||||
parameters: Dict
|
parameters: Dict
|
||||||
description: str
|
description: str
|
||||||
handler: Awaitable
|
handler: Awaitable = None
|
||||||
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||||
|
handler_module_path: str = None
|
||||||
|
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||||
|
|
||||||
|
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||||
|
"""
|
||||||
active: bool = True
|
active: bool = True
|
||||||
"""是否激活"""
|
"""是否激活"""
|
||||||
|
|
||||||
|
origin: Literal["local", "mcp"] = "local"
|
||||||
|
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||||
|
|
||||||
SUPPORTED_TYPES = [
|
# MCP 相关字段
|
||||||
"string",
|
mcp_server_name: str = None
|
||||||
"number",
|
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||||
"object",
|
mcp_client: MCPClient = None
|
||||||
"array",
|
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||||
"boolean",
|
|
||||||
] # json schema 支持的数据类型
|
def __repr__(self):
|
||||||
|
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||||
|
|
||||||
|
async def execute(self, **args) -> Any:
|
||||||
|
"""执行函数调用"""
|
||||||
|
if self.origin == "local":
|
||||||
|
if not self.handler:
|
||||||
|
raise Exception(f"Local function {self.name} has no handler")
|
||||||
|
return await self.handler(**args)
|
||||||
|
elif self.origin == "mcp":
|
||||||
|
if not self.mcp_client or not self.mcp_client.session:
|
||||||
|
raise Exception(f"MCP client for {self.name} is not available")
|
||||||
|
# 使用name属性而不是额外的mcp_tool_name
|
||||||
|
if ":" in self.name:
|
||||||
|
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||||
|
actual_tool_name = self.name.split(":")[-1]
|
||||||
|
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||||
|
else:
|
||||||
|
return await self.mcp_client.session.call_tool(self.name, args)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown function origin: {self.origin}")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
def __init__(self):
|
||||||
|
# Initialize session and client objects
|
||||||
|
self.session: Optional[mcp.ClientSession] = None
|
||||||
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
self.name = None
|
||||||
|
self.active: bool = True
|
||||||
|
self.tools: List[mcp.Tool] = []
|
||||||
|
self.server_errlogs: List[str] = []
|
||||||
|
|
||||||
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||||
|
"""连接到 MCP 服务器
|
||||||
|
|
||||||
|
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||||
|
"""
|
||||||
|
cfg = mcp_server_config.copy()
|
||||||
|
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||||
|
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||||
|
cfg = cfg["mcpServers"][key_0]
|
||||||
|
cfg.pop("active", None) # Remove active flag from config
|
||||||
|
|
||||||
|
if "url" in cfg:
|
||||||
|
# SSE transport method
|
||||||
|
self._streams_context = sse_client(url=cfg["url"])
|
||||||
|
streams = await self._streams_context.__aenter__()
|
||||||
|
|
||||||
|
# Create a new client session
|
||||||
|
# self.session = await self._session_context.__aenter__()
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.ClientSession(*streams)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
server_params = mcp.StdioServerParameters(
|
||||||
|
**cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def callback(msg: str):
|
||||||
|
# 处理 MCP 服务的错误日志
|
||||||
|
self.server_errlogs.append(msg)
|
||||||
|
|
||||||
|
stdio_transport = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.stdio_client(
|
||||||
|
server_params,
|
||||||
|
errlog=LogPipe(
|
||||||
|
level=logging.ERROR,
|
||||||
|
logger=logger,
|
||||||
|
identifier=f"MCPServer-{name}",
|
||||||
|
callback=callback,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new client session
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.ClientSession(*stdio_transport)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.session.initialize()
|
||||||
|
|
||||||
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
|
"""List all tools from the server and save them to self.tools"""
|
||||||
|
response = await self.session.list_tools()
|
||||||
|
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||||
|
self.tools = response.tools
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""Clean up resources"""
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
|
||||||
|
|
||||||
class FuncCall:
|
class FuncCall:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.func_list: List[FuncTool] = []
|
self.func_list: List[FuncTool] = []
|
||||||
|
"""内部加载的 func tools"""
|
||||||
|
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||||
|
"""MCP 服务列表"""
|
||||||
|
self.mcp_service_queue = asyncio.Queue()
|
||||||
|
"""用于外部控制 MCP 服务的启停"""
|
||||||
|
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
return len(self.func_list) == 0
|
return len(self.func_list) == 0
|
||||||
@@ -43,14 +177,16 @@ class FuncCall:
|
|||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Awaitable,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""添加函数调用工具
|
||||||
为函数调用(function-calling / tools-use)添加工具。
|
|
||||||
|
|
||||||
@param name: 函数名
|
@param name: 函数名
|
||||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||||
@param desc: 函数描述
|
@param desc: 函数描述
|
||||||
@param func_obj: 处理函数
|
@param func_obj: 处理函数
|
||||||
"""
|
"""
|
||||||
|
# check if the tool has been added before
|
||||||
|
self.remove_func(name)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
"properties": {},
|
"properties": {},
|
||||||
@@ -67,13 +203,14 @@ class FuncCall:
|
|||||||
handler=handler,
|
handler=handler,
|
||||||
)
|
)
|
||||||
self.func_list.append(_func)
|
self.func_list.append(_func)
|
||||||
|
logger.info(f"添加函数调用工具: {name}")
|
||||||
|
|
||||||
def remove_func(self, name: str) -> None:
|
def remove_func(self, name: str) -> None:
|
||||||
"""
|
"""
|
||||||
删除一个函数调用工具。
|
删除一个函数调用工具。
|
||||||
"""
|
"""
|
||||||
for i, f in enumerate(self.func_list):
|
for i, f in enumerate(self.func_list):
|
||||||
if f["name"] == name:
|
if f.name == name:
|
||||||
self.func_list.pop(i)
|
self.func_list.pop(i)
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -83,24 +220,196 @@ class FuncCall:
|
|||||||
return f
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_func_desc_openai_style(self) -> list:
|
async def _init_mcp_clients(self) -> None:
|
||||||
|
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"weather": {
|
||||||
|
"command": "uv",
|
||||||
|
"args": [
|
||||||
|
"--directory",
|
||||||
|
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
|
||||||
|
"run",
|
||||||
|
"weather.py"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
|
||||||
|
|
||||||
|
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||||
|
if not os.path.exists(mcp_json_file):
|
||||||
|
# 配置文件不存在错误处理
|
||||||
|
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||||
|
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
mcp_server_json_obj: Dict[str, Dict] = json.load(
|
||||||
|
open(mcp_json_file, "r", encoding="utf-8")
|
||||||
|
)["mcpServers"]
|
||||||
|
|
||||||
|
for name in mcp_server_json_obj.keys():
|
||||||
|
cfg = mcp_server_json_obj[name]
|
||||||
|
if cfg.get("active", True):
|
||||||
|
event = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(name, cfg, event)
|
||||||
|
)
|
||||||
|
self.mcp_client_event[name] = event
|
||||||
|
|
||||||
|
async def mcp_service_selector(self):
|
||||||
|
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||||
|
|
||||||
|
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||||
|
|
||||||
|
{"type": "init"} 初始化所有MCP客户端
|
||||||
|
|
||||||
|
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||||
|
|
||||||
|
{"type": "terminate"} 终止所有MCP客户端
|
||||||
|
|
||||||
|
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
data = await self.mcp_service_queue.get()
|
||||||
|
if data["type"] == "init":
|
||||||
|
if "name" in data:
|
||||||
|
event = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(
|
||||||
|
data["name"], data["cfg"], event
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.mcp_client_event[data["name"]] = event
|
||||||
|
else:
|
||||||
|
await self._init_mcp_clients()
|
||||||
|
elif data["type"] == "terminate":
|
||||||
|
if "name" in data:
|
||||||
|
# await self._terminate_mcp_client(data["name"])
|
||||||
|
if data["name"] in self.mcp_client_event:
|
||||||
|
self.mcp_client_event[data["name"]].set()
|
||||||
|
self.mcp_client_event.pop(data["name"], None)
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (
|
||||||
|
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
for name in self.mcp_client_dict.keys():
|
||||||
|
# await self._terminate_mcp_client(name)
|
||||||
|
# self.mcp_client_event[name].set()
|
||||||
|
if name in self.mcp_client_event:
|
||||||
|
self.mcp_client_event[name].set()
|
||||||
|
self.mcp_client_event.pop(name, None)
|
||||||
|
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||||
|
|
||||||
|
async def _init_mcp_client_task_wrapper(
|
||||||
|
self, name: str, cfg: dict, event: asyncio.Event
|
||||||
|
) -> None:
|
||||||
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
|
try:
|
||||||
|
await self._init_mcp_client(name, cfg)
|
||||||
|
await event.wait()
|
||||||
|
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||||
|
|
||||||
|
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||||
|
"""初始化单个MCP客户端"""
|
||||||
|
try:
|
||||||
|
# 先清理之前的客户端,如果存在
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
|
||||||
|
mcp_client = MCPClient()
|
||||||
|
mcp_client.name = name
|
||||||
|
self.mcp_client_dict[name] = mcp_client
|
||||||
|
await mcp_client.connect_to_server(config, name)
|
||||||
|
tools_res = await mcp_client.list_tools_and_save()
|
||||||
|
tool_names = [tool.name for tool in tools_res.tools]
|
||||||
|
|
||||||
|
# 移除该MCP服务之前的工具(如有)
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||||
|
for tool in mcp_client.tools:
|
||||||
|
func_tool = FuncTool(
|
||||||
|
name=tool.name,
|
||||||
|
parameters=tool.inputSchema,
|
||||||
|
description=tool.description,
|
||||||
|
origin="mcp",
|
||||||
|
mcp_server_name=name,
|
||||||
|
mcp_client=mcp_client,
|
||||||
|
)
|
||||||
|
self.func_list.append(func_tool)
|
||||||
|
|
||||||
|
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||||
|
# 发生错误时确保客户端被清理
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _terminate_mcp_client(self, name: str) -> None:
|
||||||
|
"""关闭并清理MCP客户端"""
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
try:
|
||||||
|
# 关闭MCP连接
|
||||||
|
await self.mcp_client_dict[name].cleanup()
|
||||||
|
del self.mcp_client_dict[name]
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||||
|
# 移除关联的FuncTool
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
logger.info(f"已关闭 MCP 服务 {name}")
|
||||||
|
|
||||||
|
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||||
"""
|
"""
|
||||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||||
"""
|
"""
|
||||||
_l = []
|
_l = []
|
||||||
|
# 处理所有工具(包括本地和MCP工具)
|
||||||
for f in self.func_list:
|
for f in self.func_list:
|
||||||
if not f.active:
|
if not f.active:
|
||||||
continue
|
continue
|
||||||
_l.append(
|
func_ = {
|
||||||
{
|
"type": "function",
|
||||||
"type": "function",
|
"function": {
|
||||||
"function": {
|
"name": f.name,
|
||||||
"name": f.name,
|
# "parameters": f.parameters,
|
||||||
"parameters": f.parameters,
|
"description": f.description,
|
||||||
"description": f.description,
|
},
|
||||||
},
|
}
|
||||||
}
|
func_["function"]["parameters"] = f.parameters
|
||||||
)
|
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||||
|
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||||
|
del func_["function"]["parameters"]
|
||||||
|
_l.append(func_)
|
||||||
return _l
|
return _l
|
||||||
|
|
||||||
def get_func_desc_anthropic_style(self) -> list:
|
def get_func_desc_anthropic_style(self) -> list:
|
||||||
@@ -137,7 +446,13 @@ class FuncCall:
|
|||||||
|
|
||||||
# 检查并添加非空的properties参数
|
# 检查并添加非空的properties参数
|
||||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||||
|
params = copy.deepcopy(params)
|
||||||
if params.get("properties", {}):
|
if params.get("properties", {}):
|
||||||
|
properties = params["properties"]
|
||||||
|
for key, value in properties.items():
|
||||||
|
if "default" in value:
|
||||||
|
del value["default"]
|
||||||
|
params["properties"] = properties
|
||||||
func_declaration["parameters"] = params
|
func_declaration["parameters"] = params
|
||||||
|
|
||||||
tools.append(func_declaration)
|
tools.append(func_declaration)
|
||||||
@@ -153,9 +468,9 @@ class FuncCall:
|
|||||||
continue
|
continue
|
||||||
_l.append(
|
_l.append(
|
||||||
{
|
{
|
||||||
"name": f["name"],
|
"name": f.name,
|
||||||
"parameters": f["parameters"],
|
"parameters": f.parameters,
|
||||||
"description": f["description"],
|
"description": f.description,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||||
@@ -205,14 +520,11 @@ class FuncCall:
|
|||||||
func_name = tool["name"]
|
func_name = tool["name"]
|
||||||
args = tool["args"]
|
args = tool["args"]
|
||||||
# 调用函数
|
# 调用函数
|
||||||
tool_callable = None
|
func_tool = self.get_func(func_name)
|
||||||
for func in self.func_list:
|
if not func_tool:
|
||||||
if func.name == func_name:
|
|
||||||
tool_callable = func.star_handler_metadata.handler
|
|
||||||
break
|
|
||||||
if not tool_callable:
|
|
||||||
raise Exception(f"Request function {func_name} not found.")
|
raise Exception(f"Request function {func_name} not found.")
|
||||||
ret = await tool_callable(**args)
|
|
||||||
|
ret = await func_tool.execute(**args)
|
||||||
if ret:
|
if ret:
|
||||||
tool_call_result.append(str(ret))
|
tool_call_result.append(str(ret))
|
||||||
return tool_call_result, True
|
return tool_call_result, True
|
||||||
@@ -222,3 +534,8 @@ class FuncCall:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.func_list)
|
return str(self.func_list)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
for name in self.mcp_client_dict.keys():
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
import asyncio
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||||
from .entites import ProviderType
|
from .entities import ProviderType
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from .register import provider_cls_map, llm_tools
|
from .register import provider_cls_map, llm_tools
|
||||||
@@ -127,14 +128,19 @@ class ProviderManager:
|
|||||||
if self.tts_enabled and not self.curr_tts_provider_inst:
|
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||||
|
|
||||||
|
# 初始化 MCP Client 连接
|
||||||
|
asyncio.create_task(
|
||||||
|
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
||||||
|
)
|
||||||
|
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
||||||
|
|
||||||
async def load_provider(self, provider_config: dict):
|
async def load_provider(self, provider_config: dict):
|
||||||
if not provider_config["enable"]:
|
if not provider_config["enable"]:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ..."
|
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..."
|
||||||
)
|
)
|
||||||
logger.debug(f"Provider Config: {provider_config}")
|
|
||||||
|
|
||||||
# 动态导入
|
# 动态导入
|
||||||
try:
|
try:
|
||||||
@@ -192,6 +198,10 @@ class ProviderManager:
|
|||||||
from .sources.fishaudio_tts_api_source import (
|
from .sources.fishaudio_tts_api_source import (
|
||||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||||
)
|
)
|
||||||
|
case "dashscope_tts":
|
||||||
|
from .sources.dashscope_tts import (
|
||||||
|
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||||
|
)
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.critical(
|
logger.critical(
|
||||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||||
@@ -300,10 +310,42 @@ class ProviderManager:
|
|||||||
|
|
||||||
if len(self.provider_insts) == 0:
|
if len(self.provider_insts) == 0:
|
||||||
self.curr_provider_inst = None
|
self.curr_provider_inst = None
|
||||||
|
elif (
|
||||||
|
self.curr_provider_inst is None
|
||||||
|
and len(self.provider_insts) > 0
|
||||||
|
and self.provider_enabled
|
||||||
|
):
|
||||||
|
self.curr_provider_inst = self.provider_insts[0]
|
||||||
|
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||||
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.stt_provider_insts) == 0:
|
if len(self.stt_provider_insts) == 0:
|
||||||
self.curr_stt_provider_inst = None
|
self.curr_stt_provider_inst = None
|
||||||
|
elif (
|
||||||
|
self.curr_stt_provider_inst is None
|
||||||
|
and len(self.stt_provider_insts) > 0
|
||||||
|
and self.stt_enabled
|
||||||
|
):
|
||||||
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||||
|
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||||
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.tts_provider_insts) == 0:
|
if len(self.tts_provider_insts) == 0:
|
||||||
self.curr_tts_provider_inst = None
|
self.curr_tts_provider_inst = None
|
||||||
|
elif (
|
||||||
|
self.curr_tts_provider_inst is None
|
||||||
|
and len(self.tts_provider_insts) > 0
|
||||||
|
and self.tts_enabled
|
||||||
|
):
|
||||||
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||||
|
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||||
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||||
|
)
|
||||||
|
|
||||||
def get_insts(self):
|
def get_insts(self):
|
||||||
return self.provider_insts
|
return self.provider_insts
|
||||||
@@ -340,3 +382,5 @@ class ProviderManager:
|
|||||||
for provider_inst in self.provider_insts:
|
for provider_inst in self.provider_insts:
|
||||||
if hasattr(provider_inst, "terminate"):
|
if hasattr(provider_inst, "terminate"):
|
||||||
await provider_inst.terminate()
|
await provider_inst.terminate()
|
||||||
|
# 清理 MCP Client 连接
|
||||||
|
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import abc
|
import abc
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from typing import TypedDict
|
from typing import TypedDict, AsyncGenerator
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from astrbot.core.provider.entites import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@@ -90,6 +90,7 @@ class Provider(AbstractProvider):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts: List = None,
|
contexts: List = None,
|
||||||
system_prompt: str = None,
|
system_prompt: str = None,
|
||||||
|
tool_calls_result: ToolCallsResult = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||||
@@ -100,13 +101,42 @@ class Provider(AbstractProvider):
|
|||||||
image_urls: 图片 URL 列表
|
image_urls: 图片 URL 列表
|
||||||
tools: Function-calling 工具
|
tools: Function-calling 工具
|
||||||
contexts: 上下文
|
contexts: 上下文
|
||||||
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||||
kwargs: 其他参数
|
kwargs: 其他参数
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
...
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = None,
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts: List = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
tool_calls_result: ToolCallsResult = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示词
|
||||||
|
session_id: 会话 ID(此属性已经被废弃)
|
||||||
|
image_urls: 图片 URL 列表
|
||||||
|
tools: Function-calling 工具
|
||||||
|
contexts: 上下文
|
||||||
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||||
|
kwargs: 其他参数
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||||
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
async def pop_record(self, context: List):
|
async def pop_record(self, context: List):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from .entites import ProviderMetaData, ProviderType
|
from .entities import ProviderMetaData, ProviderType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from .func_tool_manager import FuncCall
|
from .func_tool_manager import FuncCall
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from astrbot.api.provider import Provider, Personality
|
|||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.provider.entites import LLMResponse
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||||
from .openai_source import ProviderOpenAIOfficial
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
|
|
||||||
|
|
||||||
@@ -72,18 +73,22 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
if content.type == "text":
|
if content.type == "text":
|
||||||
# text completion
|
# text completion
|
||||||
completion_text = str(content.text).strip()
|
completion_text = str(content.text).strip()
|
||||||
llm_response.completion_text = completion_text
|
# llm_response.completion_text = completion_text
|
||||||
|
llm_response.result_chain = MessageChain().message(completion_text)
|
||||||
|
|
||||||
# Anthropic每次只返回一个函数调用
|
# Anthropic每次只返回一个函数调用
|
||||||
if completion.stop_reason == "tool_use":
|
if completion.stop_reason == "tool_use":
|
||||||
# tools call (function calling)
|
# tools call (function calling)
|
||||||
args_ls = []
|
args_ls = []
|
||||||
func_name_ls = []
|
func_name_ls = []
|
||||||
|
tool_use_ids = []
|
||||||
func_name_ls.append(content.name)
|
func_name_ls.append(content.name)
|
||||||
args_ls.append(content.input)
|
args_ls.append(content.input)
|
||||||
|
tool_use_ids.append(content.id)
|
||||||
llm_response.role = "tool"
|
llm_response.role = "tool"
|
||||||
llm_response.tools_call_args = args_ls
|
llm_response.tools_call_args = args_ls
|
||||||
llm_response.tools_call_name = func_name_ls
|
llm_response.tools_call_name = func_name_ls
|
||||||
|
llm_response.tools_call_ids = tool_use_ids
|
||||||
|
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
@@ -101,6 +106,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts=[],
|
contexts=[],
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
|
tool_calls_result: ToolCallsResult = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
if not prompt:
|
if not prompt:
|
||||||
@@ -113,6 +119,10 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
if "_no_save" in part:
|
if "_no_save" in part:
|
||||||
del part["_no_save"]
|
del part["_no_save"]
|
||||||
|
|
||||||
|
if tool_calls_result:
|
||||||
|
# 暂时这样写。
|
||||||
|
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
payloads = {"messages": context_query, **model_config}
|
||||||
@@ -137,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
messages=context_query, **model_config
|
messages=context_query, **model_config
|
||||||
)
|
)
|
||||||
llm_response = LLMResponse("assistant")
|
llm_response = LLMResponse("assistant")
|
||||||
llm_response.completion_text = response.content[0].text
|
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
||||||
llm_response.raw_completion = response
|
llm_response.raw_completion = response
|
||||||
return llm_response
|
return llm_response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -152,6 +162,33 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
|||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=...,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=...,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
# 调用 text_chat 模拟流式
|
||||||
|
llm_response = await self.text_chat(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
)
|
||||||
|
llm_response.is_chunk = True
|
||||||
|
yield llm_response
|
||||||
|
llm_response.is_chunk = False
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||||
"""组装上下文,支持文本和图片"""
|
"""组装上下文,支持文本和图片"""
|
||||||
if not image_urls:
|
if not image_urls:
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
|
import re
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
from typing import List
|
from typing import List
|
||||||
from .. import Provider, Personality
|
from .. import Provider, Personality
|
||||||
from ..entites import LLMResponse
|
from ..entities import LLMResponse
|
||||||
from ..func_tool_manager import FuncCall
|
from ..func_tool_manager import FuncCall
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from .openai_source import ProviderOpenAIOfficial
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
from astrbot.core import logger, sp
|
from astrbot.core import logger, sp
|
||||||
from dashscope import Application
|
from dashscope import Application
|
||||||
@@ -40,11 +42,28 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
raise Exception("阿里云百炼 APP 类型不能为空。")
|
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||||
self.model_name = "dashscope"
|
self.model_name = "dashscope"
|
||||||
self.variables: dict = provider_config.get("variables", {})
|
self.variables: dict = provider_config.get("variables", {})
|
||||||
|
self.rag_options: dict = provider_config.get("rag_options", {})
|
||||||
|
self.output_reference = self.rag_options.get("output_reference", False)
|
||||||
|
self.rag_options = self.rag_options.copy()
|
||||||
|
self.rag_options.pop("output_reference", None)
|
||||||
|
|
||||||
self.timeout = provider_config.get("timeout", 120)
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
if isinstance(self.timeout, str):
|
if isinstance(self.timeout, str):
|
||||||
self.timeout = int(self.timeout)
|
self.timeout = int(self.timeout)
|
||||||
|
|
||||||
|
def has_rag_options(self):
|
||||||
|
"""判断是否有 RAG 选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否有 RAG 选项
|
||||||
|
"""
|
||||||
|
if self.rag_options and (
|
||||||
|
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||||
|
or len(self.rag_options.get("file_ids", [])) > 0
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def text_chat(
|
async def text_chat(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -62,7 +81,10 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
session_var = session_vars.get(session_id, {})
|
session_var = session_vars.get(session_id, {})
|
||||||
payload_vars.update(session_var)
|
payload_vars.update(session_var)
|
||||||
|
|
||||||
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
|
if (
|
||||||
|
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||||
|
and not self.has_rag_options()
|
||||||
|
):
|
||||||
# 支持多轮对话的
|
# 支持多轮对话的
|
||||||
new_record = {"role": "user", "content": prompt}
|
new_record = {"role": "user", "content": prompt}
|
||||||
if image_urls:
|
if image_urls:
|
||||||
@@ -75,23 +97,31 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
if "_no_save" in part:
|
if "_no_save" in part:
|
||||||
del part["_no_save"]
|
del part["_no_save"]
|
||||||
# 调用阿里云百炼 API
|
# 调用阿里云百炼 API
|
||||||
|
payload = {
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"messages": context_query,
|
||||||
|
"biz_params": payload_vars or None,
|
||||||
|
}
|
||||||
partial = functools.partial(
|
partial = functools.partial(
|
||||||
Application.call,
|
Application.call,
|
||||||
app_id=self.app_id,
|
**payload,
|
||||||
api_key=self.api_key,
|
|
||||||
messages=context_query,
|
|
||||||
biz_params=payload_vars or None,
|
|
||||||
)
|
)
|
||||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||||
else:
|
else:
|
||||||
# 不支持多轮对话的
|
# 不支持多轮对话的
|
||||||
# 调用阿里云百炼 API
|
# 调用阿里云百炼 API
|
||||||
|
payload = {
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"prompt": prompt,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"biz_params": payload_vars or None,
|
||||||
|
}
|
||||||
|
if self.rag_options:
|
||||||
|
payload["rag_options"] = self.rag_options
|
||||||
partial = functools.partial(
|
partial = functools.partial(
|
||||||
Application.call,
|
Application.call,
|
||||||
app_id=self.app_id,
|
**payload,
|
||||||
promtp=prompt,
|
|
||||||
api_key=self.api_key,
|
|
||||||
biz_params=payload_vars or None,
|
|
||||||
)
|
)
|
||||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||||
|
|
||||||
@@ -103,11 +133,56 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|||||||
)
|
)
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
role="err",
|
role="err",
|
||||||
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
result_chain=MessageChain().message(
|
||||||
|
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
output_text = response.output.get("text", "")
|
output_text = response.output.get("text", "")
|
||||||
return LLMResponse(role="assistant", completion_text=output_text)
|
# RAG 引用脚标格式化
|
||||||
|
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||||
|
if self.output_reference and response.output.get("doc_references", None):
|
||||||
|
ref_str = ""
|
||||||
|
for ref in response.output.get("doc_references", []):
|
||||||
|
ref_title = (
|
||||||
|
ref.get("title", "")
|
||||||
|
if ref.get("title")
|
||||||
|
else ref.get("doc_name", "")
|
||||||
|
)
|
||||||
|
ref_str += f"{ref['index_id']}. {ref_title}\n"
|
||||||
|
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||||
|
|
||||||
|
llm_response = LLMResponse("assistant")
|
||||||
|
llm_response.result_chain = MessageChain().message(output_text)
|
||||||
|
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=...,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=...,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
# 调用 text_chat 模拟流式
|
||||||
|
llm_response = await self.text_chat(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
)
|
||||||
|
llm_response.is_chunk = True
|
||||||
|
yield llm_response
|
||||||
|
llm_response.is_chunk = False
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
async def forget(self, session_id):
|
async def forget(self, session_id):
|
||||||
return True
|
return True
|
||||||
|
|||||||
38
astrbot/core/provider/sources/dashscope_tts.py
Normal file
38
astrbot/core/provider/sources/dashscope_tts.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import dashscope
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from dashscope.audio.tts_v2 import *
|
||||||
|
from ..provider import TTSProvider
|
||||||
|
from ..entities import ProviderType
|
||||||
|
from ..register import register_provider_adapter
|
||||||
|
|
||||||
|
|
||||||
|
@register_provider_adapter(
|
||||||
|
"dashscope_tts", "Dashscope TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||||
|
)
|
||||||
|
class ProviderDashscopeTTSAPI(TTSProvider):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_config: dict,
|
||||||
|
provider_settings: dict,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(provider_config, provider_settings)
|
||||||
|
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||||
|
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||||
|
self.set_model(provider_config.get("model", None))
|
||||||
|
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||||
|
dashscope.api_key = self.chosen_api_key
|
||||||
|
|
||||||
|
async def get_audio(self, text: str) -> str:
|
||||||
|
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||||
|
self.synthesizer = SpeechSynthesizer(
|
||||||
|
model=self.get_model(),
|
||||||
|
voice=self.voice,
|
||||||
|
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||||
|
)
|
||||||
|
audio = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.synthesizer.call, text, self.timeout_ms
|
||||||
|
)
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(audio)
|
||||||
|
return path
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
|
import astrbot.core.message.components as Comp
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from .. import Provider, Personality
|
from .. import Provider, Personality
|
||||||
from ..entites import LLMResponse
|
from ..entities import LLMResponse
|
||||||
from ..func_tool_manager import FuncCall
|
from ..func_tool_manager import FuncCall
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||||
from astrbot.core import logger, sp
|
from astrbot.core import logger, sp
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||||
@@ -30,7 +33,6 @@ class ProviderDify(Provider):
|
|||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise Exception("Dify API Key 不能为空。")
|
raise Exception("Dify API Key 不能为空。")
|
||||||
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||||
self.api_client = DifyAPIClient(self.api_key, api_base)
|
|
||||||
self.api_type = provider_config.get("dify_api_type", "")
|
self.api_type = provider_config.get("dify_api_type", "")
|
||||||
if not self.api_type:
|
if not self.api_type:
|
||||||
raise Exception("Dify API 类型不能为空。")
|
raise Exception("Dify API 类型不能为空。")
|
||||||
@@ -41,15 +43,19 @@ class ProviderDify(Provider):
|
|||||||
self.dify_query_input_key = provider_config.get(
|
self.dify_query_input_key = provider_config.get(
|
||||||
"dify_query_input_key", "astrbot_text_query"
|
"dify_query_input_key", "astrbot_text_query"
|
||||||
)
|
)
|
||||||
self.variables: dict = provider_config.get("variables", {})
|
|
||||||
if not self.dify_query_input_key:
|
if not self.dify_query_input_key:
|
||||||
self.dify_query_input_key = "astrbot_text_query"
|
self.dify_query_input_key = "astrbot_text_query"
|
||||||
|
if not self.workflow_output_key:
|
||||||
|
self.workflow_output_key = "astrbot_wf_output"
|
||||||
|
self.variables: dict = provider_config.get("variables", {})
|
||||||
self.timeout = provider_config.get("timeout", 120)
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
if isinstance(self.timeout, str):
|
if isinstance(self.timeout, str):
|
||||||
self.timeout = int(self.timeout)
|
self.timeout = int(self.timeout)
|
||||||
self.conversation_ids = {}
|
self.conversation_ids = {}
|
||||||
"""记录当前 session id 的对话 ID"""
|
"""记录当前 session id 的对话 ID"""
|
||||||
|
|
||||||
|
self.api_client = DifyAPIClient(self.api_key, api_base)
|
||||||
|
|
||||||
async def text_chat(
|
async def text_chat(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -65,26 +71,27 @@ class ProviderDify(Provider):
|
|||||||
|
|
||||||
files_payload = []
|
files_payload = []
|
||||||
for image_url in image_urls:
|
for image_url in image_urls:
|
||||||
if image_url.startswith("http"):
|
image_path = (
|
||||||
image_path = await download_image_by_url(image_url)
|
await download_image_by_url(image_url)
|
||||||
file_response = await self.api_client.file_upload(
|
if image_url.startswith("http")
|
||||||
image_path, user=session_id
|
else image_url
|
||||||
|
)
|
||||||
|
file_response = await self.api_client.file_upload(
|
||||||
|
image_path, user=session_id
|
||||||
|
)
|
||||||
|
logger.debug(f"Dify 上传图片响应:{file_response}")
|
||||||
|
if "id" not in file_response:
|
||||||
|
logger.warning(
|
||||||
|
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
||||||
)
|
)
|
||||||
if "id" not in file_response:
|
continue
|
||||||
logger.warning(
|
files_payload.append(
|
||||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
{
|
||||||
)
|
"type": "image",
|
||||||
continue
|
"transfer_method": "local_file",
|
||||||
files_payload.append(
|
"upload_file_id": file_response["id"],
|
||||||
{
|
}
|
||||||
"type": "image",
|
)
|
||||||
"transfer_method": "local_file",
|
|
||||||
"upload_file_id": file_response["id"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# TODO: 处理更多情况
|
|
||||||
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
|
|
||||||
|
|
||||||
# 获得会话变量
|
# 获得会话变量
|
||||||
payload_vars = self.variables.copy()
|
payload_vars = self.variables.copy()
|
||||||
@@ -95,7 +102,10 @@ class ProviderDify(Provider):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
match self.api_type:
|
match self.api_type:
|
||||||
case "chat" | "agent":
|
case "chat" | "agent" | "chatflow":
|
||||||
|
if not prompt:
|
||||||
|
prompt = "请描述这张图片。"
|
||||||
|
|
||||||
async for chunk in self.api_client.chat_messages(
|
async for chunk in self.api_client.chat_messages(
|
||||||
inputs={
|
inputs={
|
||||||
**payload_vars,
|
**payload_vars,
|
||||||
@@ -148,8 +158,9 @@ class ProviderDify(Provider):
|
|||||||
)
|
)
|
||||||
case "workflow_finished":
|
case "workflow_finished":
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。"
|
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
|
||||||
)
|
)
|
||||||
|
logger.debug(f"Dify 工作流结果:{chunk}")
|
||||||
if chunk["data"]["error"]:
|
if chunk["data"]["error"]:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
f"Dify 工作流出现错误:{chunk['data']['error']}"
|
||||||
@@ -164,9 +175,7 @@ class ProviderDify(Provider):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
|
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
|
||||||
)
|
)
|
||||||
result = chunk["data"]["outputs"][
|
result = chunk
|
||||||
self.workflow_output_key
|
|
||||||
]
|
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -176,7 +185,81 @@ class ProviderDify(Provider):
|
|||||||
if not result:
|
if not result:
|
||||||
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
|
||||||
|
|
||||||
return LLMResponse(role="assistant", completion_text=result)
|
chain = await self.parse_dify_result(result)
|
||||||
|
|
||||||
|
return LLMResponse(role="assistant", result_chain=chain)
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=...,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=...,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
# 调用 text_chat 模拟流式
|
||||||
|
llm_response = await self.text_chat(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
)
|
||||||
|
llm_response.is_chunk = True
|
||||||
|
yield llm_response
|
||||||
|
llm_response.is_chunk = False
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
|
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||||
|
if isinstance(chunk, str):
|
||||||
|
# Chat
|
||||||
|
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||||
|
|
||||||
|
async def parse_file(item: dict) -> Comp:
|
||||||
|
match item["type"]:
|
||||||
|
case "image":
|
||||||
|
return Comp.Image(file=item["url"], url=item["url"])
|
||||||
|
case "audio":
|
||||||
|
# 仅支持 wav
|
||||||
|
path = f"data/temp/{item['filename']}.wav"
|
||||||
|
await download_file(item["url"], path)
|
||||||
|
return Comp.Image(file=item["url"], url=item["url"])
|
||||||
|
case "video":
|
||||||
|
return Comp.Video(file=item["url"])
|
||||||
|
case _:
|
||||||
|
return Comp.File(name=item["filename"], file=item["url"])
|
||||||
|
|
||||||
|
output = chunk["data"]["outputs"][self.workflow_output_key]
|
||||||
|
chains = []
|
||||||
|
if isinstance(output, str):
|
||||||
|
# 纯文本输出
|
||||||
|
chains.append(Comp.Plain(output))
|
||||||
|
elif isinstance(output, list):
|
||||||
|
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
|
||||||
|
for item in output:
|
||||||
|
# handle Array[File]
|
||||||
|
if (
|
||||||
|
not isinstance(item, dict)
|
||||||
|
or item.get("dify_model_identity", "") != "__dify__file__"
|
||||||
|
):
|
||||||
|
chains.append(Comp.Plain(str(output)))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
chains.append(Comp.Plain(str(output)))
|
||||||
|
|
||||||
|
# scan file
|
||||||
|
files = chunk["data"].get("files", [])
|
||||||
|
for item in files:
|
||||||
|
comp = await parse_file(item)
|
||||||
|
chains.append(comp)
|
||||||
|
|
||||||
|
return MessageChain(chain=chains)
|
||||||
|
|
||||||
async def forget(self, session_id):
|
async def forget(self, session_id):
|
||||||
self.conversation_ids[session_id] = ""
|
self.conversation_ids[session_id] = ""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import edge_tts
|
|||||||
import subprocess
|
import subprocess
|
||||||
import asyncio
|
import asyncio
|
||||||
from ..provider import TTSProvider
|
from ..provider import TTSProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
@@ -35,6 +35,8 @@ class ProviderEdgeTTS(TTSProvider):
|
|||||||
self.pitch = provider_config.get("pitch", None)
|
self.pitch = provider_config.get("pitch", None)
|
||||||
self.timeout = provider_config.get("timeout", 30)
|
self.timeout = provider_config.get("timeout", 30)
|
||||||
|
|
||||||
|
self.proxy = os.getenv("https_proxy", None)
|
||||||
|
|
||||||
self.set_model("edge_tts")
|
self.set_model("edge_tts")
|
||||||
|
|
||||||
async def get_audio(self, text: str) -> str:
|
async def get_audio(self, text: str) -> str:
|
||||||
@@ -42,7 +44,7 @@ class ProviderEdgeTTS(TTSProvider):
|
|||||||
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||||
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
||||||
|
|
||||||
# 构建Edge TTS参数
|
# 构建 Edge TTS 参数
|
||||||
kwargs = {"text": text, "voice": self.voice}
|
kwargs = {"text": text, "voice": self.voice}
|
||||||
if self.rate:
|
if self.rate:
|
||||||
kwargs["rate"] = self.rate
|
kwargs["rate"] = self.rate
|
||||||
@@ -52,12 +54,20 @@ class ProviderEdgeTTS(TTSProvider):
|
|||||||
kwargs["pitch"] = self.pitch
|
kwargs["pitch"] = self.pitch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
communicate = edge_tts.Communicate(**kwargs)
|
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
|
||||||
await communicate.save(mp3_path)
|
await communicate.save(mp3_path)
|
||||||
|
|
||||||
# 使用ffmpeg将MP3转换为标准WAV格式
|
try:
|
||||||
_ = await asyncio.create_subprocess_exec(
|
from pyffmpeg import FFmpeg
|
||||||
[
|
|
||||||
|
ff = FFmpeg()
|
||||||
|
ff.convert(input=mp3_path, output=wav_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||||
|
# use ffmpeg command line
|
||||||
|
|
||||||
|
# 使用ffmpeg将MP3转换为标准WAV格式
|
||||||
|
p = await asyncio.create_subprocess_exec(
|
||||||
"ffmpeg",
|
"ffmpeg",
|
||||||
"-y", # 覆盖输出文件
|
"-y", # 覆盖输出文件
|
||||||
"-i",
|
"-i",
|
||||||
@@ -68,11 +78,20 @@ class ProviderEdgeTTS(TTSProvider):
|
|||||||
"24000", # 采样率24kHz (适合微信语音)
|
"24000", # 采样率24kHz (适合微信语音)
|
||||||
"-ac",
|
"-ac",
|
||||||
"1", # 单声道
|
"1", # 单声道
|
||||||
|
"-af",
|
||||||
|
"apad=pad_dur=2", # 确保输出时长准确
|
||||||
|
"-fflags",
|
||||||
|
"+genpts", # 强制生成时间戳
|
||||||
|
"-hide_banner", # 隐藏版本信息
|
||||||
wav_path, # 输出文件
|
wav_path, # 输出文件
|
||||||
],
|
stdout=subprocess.PIPE,
|
||||||
capture_output=True,
|
stderr=subprocess.PIPE,
|
||||||
check=True,
|
)
|
||||||
)
|
# 等待进程完成并获取输出
|
||||||
|
stdout, stderr = await p.communicate()
|
||||||
|
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
|
||||||
|
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
|
||||||
|
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
|
||||||
|
|
||||||
os.remove(mp3_path)
|
os.remove(mp3_path)
|
||||||
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
|
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
|
||||||
@@ -82,13 +101,15 @@ class ProviderEdgeTTS(TTSProvider):
|
|||||||
raise RuntimeError("生成的WAV文件不存在或为空")
|
raise RuntimeError("生成的WAV文件不存在或为空")
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logger.error(f"FFmpeg转换失败: {e.stderr.decode() if e.stderr else str(e)}")
|
logger.error(
|
||||||
|
f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if os.path.exists(mp3_path):
|
if os.path.exists(mp3_path):
|
||||||
os.remove(mp3_path)
|
os.remove(mp3_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
raise RuntimeError(f"FFmpeg转换失败: {str(e)}")
|
raise RuntimeError(f"FFmpeg 转换失败: {str(e)}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"音频生成失败: {str(e)}")
|
logger.error(f"音频生成失败: {str(e)}")
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, conint
|
|||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
from ..provider import TTSProvider
|
from ..provider import TTSProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,75 +1,55 @@
|
|||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import aiohttp
|
import json
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from typing import Dict, List, Optional
|
||||||
from astrbot.core.db import BaseDatabase
|
from collections.abc import AsyncGenerator
|
||||||
from astrbot.api.provider import Provider, Personality
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
from google.genai.errors import APIError
|
||||||
|
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
from astrbot.api.provider import Personality, Provider
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from typing import List
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.provider.entites import LLMResponse
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleGoogleGenAIClient:
|
class SuppressNonTextPartsWarning(logging.Filter):
|
||||||
def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None:
|
"""过滤 Gemini SDK 中的非文本部分警告"""
|
||||||
self.api_key = api_key
|
|
||||||
if api_base.endswith("/"):
|
|
||||||
self.api_base = api_base[:-1]
|
|
||||||
else:
|
|
||||||
self.api_base = api_base
|
|
||||||
self.client = aiohttp.ClientSession(trust_env=True)
|
|
||||||
self.timeout = timeout
|
|
||||||
|
|
||||||
async def models_list(self) -> List[str]:
|
def filter(self, record):
|
||||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
return "there are non-text parts in the response" not in record.getMessage()
|
||||||
async with self.client.get(request_url, timeout=self.timeout) as resp:
|
|
||||||
response = await resp.json()
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for model in response["models"]:
|
|
||||||
if "generateContent" in model["supportedGenerationMethods"]:
|
|
||||||
models.append(model["name"].replace("models/", ""))
|
|
||||||
return models
|
|
||||||
|
|
||||||
async def generate_content(
|
logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning())
|
||||||
self,
|
|
||||||
contents: List[dict],
|
|
||||||
model: str = "gemini-1.5-flash",
|
|
||||||
system_instruction: str = "",
|
|
||||||
tools: dict = None,
|
|
||||||
):
|
|
||||||
payload = {}
|
|
||||||
if system_instruction:
|
|
||||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
|
||||||
if tools:
|
|
||||||
payload["tools"] = [tools]
|
|
||||||
payload["contents"] = contents
|
|
||||||
logger.debug(f"payload: {payload}")
|
|
||||||
request_url = (
|
|
||||||
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
|
||||||
)
|
|
||||||
async with self.client.post(
|
|
||||||
request_url, json=payload, timeout=self.timeout
|
|
||||||
) as resp:
|
|
||||||
if "application/json" in resp.headers.get("Content-Type"):
|
|
||||||
try:
|
|
||||||
response = await resp.json()
|
|
||||||
except Exception as e:
|
|
||||||
text = await resp.text()
|
|
||||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
|
||||||
raise e
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
text = await resp.text()
|
|
||||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
|
||||||
raise Exception("Gemini 返回了非 json 数据: ")
|
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
@register_provider_adapter(
|
||||||
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
|
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
|
||||||
)
|
)
|
||||||
class ProviderGoogleGenAI(Provider):
|
class ProviderGoogleGenAI(Provider):
|
||||||
|
CATEGORY_MAPPING = {
|
||||||
|
"harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||||
|
"hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||||
|
"sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||||
|
"dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||||
|
}
|
||||||
|
|
||||||
|
THRESHOLD_MAPPING = {
|
||||||
|
"BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
"BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||||
|
"BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_config: dict,
|
provider_config: dict,
|
||||||
@@ -85,99 +65,384 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
db_helper,
|
db_helper,
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.chosen_api_key = None
|
|
||||||
self.api_keys: List = provider_config.get("key", [])
|
self.api_keys: List = provider_config.get("key", [])
|
||||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||||
self.timeout = provider_config.get("timeout", 180)
|
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||||
if isinstance(self.timeout, str):
|
|
||||||
self.timeout = int(self.timeout)
|
self.api_base: Optional[str] = provider_config.get("api_base", None)
|
||||||
self.client = SimpleGoogleGenAIClient(
|
if self.api_base and self.api_base.endswith("/"):
|
||||||
api_key=self.chosen_api_key,
|
self.api_base = self.api_base[:-1]
|
||||||
api_base=provider_config.get("api_base", None),
|
|
||||||
timeout=self.timeout,
|
self._init_client()
|
||||||
)
|
|
||||||
self.set_model(provider_config["model_config"]["model"])
|
self.set_model(provider_config["model_config"]["model"])
|
||||||
|
self._init_safety_settings()
|
||||||
|
|
||||||
async def get_models(self):
|
def _init_client(self) -> None:
|
||||||
return await self.client.models_list()
|
"""初始化Gemini客户端"""
|
||||||
|
self.client = genai.Client(
|
||||||
|
api_key=self.chosen_api_key,
|
||||||
|
http_options=types.HttpOptions(
|
||||||
|
base_url=self.api_base,
|
||||||
|
timeout=self.timeout * 1000, # 毫秒
|
||||||
|
),
|
||||||
|
).aio
|
||||||
|
|
||||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
def _init_safety_settings(self) -> None:
|
||||||
tool = None
|
"""初始化安全设置"""
|
||||||
if tools:
|
user_safety_config = self.provider_config.get("gm_safety_settings", {})
|
||||||
tool = tools.get_func_desc_google_genai_style()
|
self.safety_settings = [
|
||||||
if not tool:
|
types.SafetySetting(
|
||||||
tool = None
|
category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str]
|
||||||
|
)
|
||||||
|
for config_key, harm_category in self.CATEGORY_MAPPING.items()
|
||||||
|
if (threshold_str := user_safety_config.get(config_key))
|
||||||
|
and threshold_str in self.THRESHOLD_MAPPING
|
||||||
|
]
|
||||||
|
|
||||||
system_instruction = ""
|
async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
|
||||||
|
"""处理API错误,返回是否需要重试"""
|
||||||
|
if e.code == 429 or "API key not valid" in e.message:
|
||||||
|
keys.remove(self.chosen_api_key)
|
||||||
|
if len(keys) > 0:
|
||||||
|
self.set_key(random.choice(keys))
|
||||||
|
logger.info(
|
||||||
|
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
|
||||||
|
)
|
||||||
|
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _prepare_query_config(
|
||||||
|
self,
|
||||||
|
payloads: dict,
|
||||||
|
tools: Optional[FuncCall] = None,
|
||||||
|
system_instruction: Optional[str] = None,
|
||||||
|
modalities: Optional[List[str]] = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
) -> types.GenerateContentConfig:
|
||||||
|
"""准备查询配置"""
|
||||||
|
if not modalities:
|
||||||
|
modalities = ["Text"]
|
||||||
|
|
||||||
|
# 流式输出不支持图片模态
|
||||||
|
if (
|
||||||
|
self.provider_settings.get("streaming_response", False)
|
||||||
|
and "Image" in modalities
|
||||||
|
):
|
||||||
|
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||||
|
modalities = ["Text"]
|
||||||
|
|
||||||
|
tool_list = None
|
||||||
|
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||||
|
native_search = self.provider_config.get("gm_native_search", False)
|
||||||
|
|
||||||
|
if native_coderunner:
|
||||||
|
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
|
||||||
|
if native_search:
|
||||||
|
logger.warning("已启用代码执行工具,搜索工具将被忽略")
|
||||||
|
if tools:
|
||||||
|
logger.warning("已启用代码执行工具,函数工具将被忽略")
|
||||||
|
elif native_search:
|
||||||
|
tool_list = [types.Tool(google_search=types.GoogleSearch())]
|
||||||
|
if tools:
|
||||||
|
logger.warning("已启用搜索工具,函数工具将被忽略")
|
||||||
|
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
|
||||||
|
tool_list = [
|
||||||
|
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||||
|
]
|
||||||
|
return types.GenerateContentConfig(
|
||||||
|
system_instruction=system_instruction,
|
||||||
|
temperature=temperature,
|
||||||
|
max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"),
|
||||||
|
top_p=payloads.get("top_p") or payloads.get("topP"),
|
||||||
|
top_k=payloads.get("top_k") or payloads.get("topK"),
|
||||||
|
frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"),
|
||||||
|
presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"),
|
||||||
|
stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
|
||||||
|
response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"),
|
||||||
|
logprobs=payloads.get("logprobs"),
|
||||||
|
seed=payloads.get("seed"),
|
||||||
|
response_modalities=modalities,
|
||||||
|
tools=tool_list,
|
||||||
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||||
|
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||||
|
disable=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_conversation(self, payloads: Dict) -> List[types.Content]:
|
||||||
|
"""准备 Gemini SDK 的 Content 列表"""
|
||||||
|
|
||||||
|
def create_text_part(text: str) -> types.UserContent:
|
||||||
|
content_a = text if text else " "
|
||||||
|
if not text:
|
||||||
|
logger.warning("文本内容为空,已添加空格占位")
|
||||||
|
return types.UserContent(parts=[types.Part.from_text(text=content_a)])
|
||||||
|
|
||||||
|
def process_image_url(image_url_dict: dict) -> types.Part:
|
||||||
|
url = image_url_dict["url"]
|
||||||
|
mime_type = url.split(":")[1].split(";")[0]
|
||||||
|
image_bytes = base64.b64decode(url.split(",", 1)[1])
|
||||||
|
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
|
||||||
|
|
||||||
|
gemini_contents: List[types.Content] = []
|
||||||
|
native_tool_enabled = any(
|
||||||
|
[
|
||||||
|
self.provider_config.get("gm_native_coderunner", False),
|
||||||
|
self.provider_config.get("gm_native_search", False),
|
||||||
|
]
|
||||||
|
)
|
||||||
for message in payloads["messages"]:
|
for message in payloads["messages"]:
|
||||||
if message["role"] == "system":
|
role, content = message["role"], message.get("content")
|
||||||
system_instruction = message["content"]
|
|
||||||
break
|
|
||||||
|
|
||||||
google_genai_conversation = []
|
if role == "user":
|
||||||
for message in payloads["messages"]:
|
if isinstance(content, str):
|
||||||
if message["role"] == "user":
|
gemini_contents.append(create_text_part(content))
|
||||||
if isinstance(message["content"], str):
|
elif isinstance(content, list):
|
||||||
if not message["content"]:
|
parts = [
|
||||||
message["content"] = "<empty_content>"
|
types.Part.from_text(text=item["text"] or " ")
|
||||||
|
if item["type"] == "text"
|
||||||
|
else process_image_url(item["image_url"])
|
||||||
|
for item in content
|
||||||
|
]
|
||||||
|
gemini_contents.append(types.UserContent(parts=parts))
|
||||||
|
|
||||||
google_genai_conversation.append(
|
elif role == "assistant":
|
||||||
{"role": "user", "parts": [{"text": message["content"]}]}
|
if content:
|
||||||
|
gemini_contents.append(
|
||||||
|
types.ModelContent(parts=[types.Part.from_text(text=content)])
|
||||||
)
|
)
|
||||||
elif isinstance(message["content"], list):
|
elif "tool_calls" in message and not native_tool_enabled:
|
||||||
# images
|
gemini_contents.extend(
|
||||||
parts = []
|
[
|
||||||
for part in message["content"]:
|
types.ModelContent(
|
||||||
if part["type"] == "text":
|
parts=[
|
||||||
if not part["text"]:
|
types.Part.from_function_call(
|
||||||
part["text"] = "<empty_content>"
|
name=tool["function"]["name"],
|
||||||
parts.append({"text": part["text"]})
|
args=json.loads(tool["function"]["arguments"]),
|
||||||
elif part["type"] == "image_url":
|
)
|
||||||
parts.append(
|
]
|
||||||
{
|
|
||||||
"inline_data": {
|
|
||||||
"mime_type": "image/jpeg",
|
|
||||||
"data": part["image_url"]["url"].replace(
|
|
||||||
"data:image/jpeg;base64,", ""
|
|
||||||
), # base64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
for tool in message["tool_calls"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
||||||
|
if native_tool_enabled:
|
||||||
|
logger.warning(
|
||||||
|
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文"
|
||||||
|
)
|
||||||
|
gemini_contents.append(
|
||||||
|
types.ModelContent(parts=[types.Part.from_text(text=" ")])
|
||||||
|
)
|
||||||
|
|
||||||
elif message["role"] == "assistant":
|
elif role == "tool" and not native_tool_enabled:
|
||||||
if not message["content"]:
|
gemini_contents.append(
|
||||||
message["content"] = "<empty_content>"
|
types.UserContent(
|
||||||
google_genai_conversation.append(
|
parts=[
|
||||||
{"role": "model", "parts": [{"text": message["content"]}]}
|
types.Part.from_function_response(
|
||||||
|
name=message["tool_call_id"],
|
||||||
|
response={
|
||||||
|
"name": message["tool_call_id"],
|
||||||
|
"content": message["content"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
return gemini_contents
|
||||||
|
|
||||||
result = await self.client.generate_content(
|
@staticmethod
|
||||||
contents=google_genai_conversation,
|
def _process_content_parts(
|
||||||
model=self.get_model(),
|
result: types.GenerateContentResponse, llm_response: LLMResponse
|
||||||
system_instruction=system_instruction,
|
) -> MessageChain:
|
||||||
tools=tool,
|
"""处理内容部分并构建消息链"""
|
||||||
)
|
finish_reason = result.candidates[0].finish_reason
|
||||||
logger.debug(f"result: {result}")
|
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||||
|
|
||||||
if "candidates" not in result:
|
if finish_reason == types.FinishReason.SAFETY:
|
||||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||||
|
|
||||||
candidates = result["candidates"][0]["content"]["parts"]
|
if finish_reason in {
|
||||||
llm_response = LLMResponse("assistant")
|
types.FinishReason.PROHIBITED_CONTENT,
|
||||||
for candidate in candidates:
|
types.FinishReason.SPII,
|
||||||
if "text" in candidate:
|
types.FinishReason.BLOCKLIST,
|
||||||
llm_response.completion_text += candidate["text"]
|
}:
|
||||||
elif "functionCall" in candidate:
|
raise Exception("模型生成内容违反Gemini平台政策")
|
||||||
|
|
||||||
|
# 防止旧版本SDK不存在IMAGE_SAFETY
|
||||||
|
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
|
||||||
|
if finish_reason == types.FinishReason.IMAGE_SAFETY:
|
||||||
|
raise Exception("模型生成内容违反Gemini平台政策")
|
||||||
|
|
||||||
|
if not result_parts:
|
||||||
|
logger.debug(result.candidates)
|
||||||
|
raise Exception("API 返回的内容为空。")
|
||||||
|
|
||||||
|
chain = []
|
||||||
|
part: types.Part
|
||||||
|
|
||||||
|
# 暂时这样Fallback
|
||||||
|
if all(
|
||||||
|
part.inline_data and part.inline_data.mime_type.startswith("image/")
|
||||||
|
for part in result_parts
|
||||||
|
):
|
||||||
|
chain.append(Comp.Plain("这是图片"))
|
||||||
|
for part in result_parts:
|
||||||
|
if part.text:
|
||||||
|
chain.append(Comp.Plain(part.text))
|
||||||
|
elif part.function_call:
|
||||||
llm_response.role = "tool"
|
llm_response.role = "tool"
|
||||||
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
|
llm_response.tools_call_name.append(part.function_call.name)
|
||||||
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
|
llm_response.tools_call_args.append(part.function_call.args)
|
||||||
|
# gemini 返回的 function_call.id 可能为 None
|
||||||
|
llm_response.tools_call_ids.append(
|
||||||
|
part.function_call.id or part.function_call.name
|
||||||
|
)
|
||||||
|
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||||
|
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||||
|
return MessageChain(chain=chain)
|
||||||
|
|
||||||
llm_response.completion_text = llm_response.completion_text.strip()
|
async def _query(
|
||||||
|
self, payloads: dict, tools: FuncCall
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""非流式请求 Gemini API"""
|
||||||
|
system_instruction = next(
|
||||||
|
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
modalities = ["Text"]
|
||||||
|
if self.provider_config.get("gm_resp_image_modal", False):
|
||||||
|
modalities.append("Image")
|
||||||
|
|
||||||
|
conversation = self._prepare_conversation(payloads)
|
||||||
|
temperature=payloads.get("temperature", 0.7)
|
||||||
|
|
||||||
|
result: Optional[types.GenerateContentResponse] = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
config = await self._prepare_query_config(
|
||||||
|
payloads, tools, system_instruction, modalities, temperature
|
||||||
|
)
|
||||||
|
result = await self.client.models.generate_content(
|
||||||
|
model=self.get_model(),
|
||||||
|
contents=conversation,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
|
||||||
|
if temperature > 2:
|
||||||
|
raise Exception("温度参数已超过最大值2,仍然发生recitation")
|
||||||
|
temperature += 0.2
|
||||||
|
logger.warning(
|
||||||
|
f"发生了recitation,正在提高温度至{temperature:.1f}重试..."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
except APIError as e:
|
||||||
|
if "Developer instruction is not enabled" in e.message:
|
||||||
|
logger.warning(
|
||||||
|
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||||
|
)
|
||||||
|
system_instruction = None
|
||||||
|
elif "Function calling is not enabled" in e.message:
|
||||||
|
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||||
|
tools = None
|
||||||
|
elif (
|
||||||
|
"Multi-modal output is not supported" in e.message
|
||||||
|
or "Model does not support the requested response modalities"
|
||||||
|
in e.message
|
||||||
|
or "only supports text output" in e.message
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"{self.get_model()} 不支持多模态输出,降级为文本模态"
|
||||||
|
)
|
||||||
|
modalities = ["Text"]
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
continue
|
||||||
|
|
||||||
|
llm_response = LLMResponse("assistant")
|
||||||
|
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
async def _query_stream(
|
||||||
|
self, payloads: dict, tools: FuncCall
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""流式请求 Gemini API"""
|
||||||
|
system_instruction = next(
|
||||||
|
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation = self._prepare_conversation(payloads)
|
||||||
|
|
||||||
|
result = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
config = await self._prepare_query_config(
|
||||||
|
payloads, tools, system_instruction
|
||||||
|
)
|
||||||
|
result = await self.client.models.generate_content_stream(
|
||||||
|
model=self.get_model(),
|
||||||
|
contents=conversation,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except APIError as e:
|
||||||
|
if "Developer instruction is not enabled" in e.message:
|
||||||
|
logger.warning(
|
||||||
|
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||||
|
)
|
||||||
|
system_instruction = None
|
||||||
|
elif "Function calling is not enabled" in e.message:
|
||||||
|
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||||
|
tools = None
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
continue
|
||||||
|
|
||||||
|
async for chunk in result:
|
||||||
|
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||||
|
|
||||||
|
if chunk.candidates[0].content.parts and any(
|
||||||
|
part.function_call for part in chunk.candidates[0].content.parts
|
||||||
|
):
|
||||||
|
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||||
|
llm_response.result_chain = self._process_content_parts(
|
||||||
|
chunk, llm_response
|
||||||
|
)
|
||||||
|
yield llm_response
|
||||||
|
break
|
||||||
|
|
||||||
|
if chunk.text:
|
||||||
|
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
|
if chunk.candidates[0].finish_reason:
|
||||||
|
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||||
|
if not chunk.candidates[0].content.parts:
|
||||||
|
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||||
|
else:
|
||||||
|
llm_response.result_chain = self._process_content_parts(
|
||||||
|
chunk, llm_response
|
||||||
|
)
|
||||||
|
yield llm_response
|
||||||
|
break
|
||||||
|
|
||||||
async def text_chat(
|
async def text_chat(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -186,10 +451,10 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts=[],
|
contexts=[],
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
context_query = []
|
|
||||||
context_query = [*contexts, new_record]
|
context_query = [*contexts, new_record]
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||||
@@ -198,85 +463,98 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
if "_no_save" in part:
|
if "_no_save" in part:
|
||||||
del part["_no_save"]
|
del part["_no_save"]
|
||||||
|
|
||||||
|
# tool calls result
|
||||||
|
if tool_calls_result:
|
||||||
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
model_config["model"] = self.get_model()
|
model_config["model"] = self.get_model()
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
payloads = {"messages": context_query, **model_config}
|
||||||
llm_response = None
|
|
||||||
|
|
||||||
retry = 10
|
retry = 10
|
||||||
keys = self.api_keys.copy()
|
keys = self.api_keys.copy()
|
||||||
chosen_key = random.choice(keys)
|
|
||||||
|
|
||||||
for i in range(retry):
|
for _ in range(retry):
|
||||||
try:
|
try:
|
||||||
self.client.api_key = chosen_key
|
return await self._query(payloads, func_tool)
|
||||||
llm_response = await self._query(payloads, func_tool)
|
except APIError as e:
|
||||||
|
if await self._handle_api_error(e, keys):
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
except Exception as e:
|
|
||||||
if "maximum context length" in str(e):
|
|
||||||
retry_cnt = 20
|
|
||||||
while retry_cnt > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await self.pop_record(context_query)
|
|
||||||
llm_response = await self._query(payloads, func_tool)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if "maximum context length" in str(e):
|
|
||||||
retry_cnt -= 1
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
if retry_cnt == 0:
|
|
||||||
llm_response = LLMResponse(
|
|
||||||
"err", "err: 请尝试 /reset 重置会话"
|
|
||||||
)
|
|
||||||
elif "Function calling is not enabled" in str(e):
|
|
||||||
logger.info(
|
|
||||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
|
||||||
)
|
|
||||||
if "tools" in payloads:
|
|
||||||
del payloads["tools"]
|
|
||||||
llm_response = await self._query(payloads, None)
|
|
||||||
break
|
|
||||||
elif "429" in str(e) or "API key not valid" in str(e):
|
|
||||||
keys.remove(chosen_key)
|
|
||||||
if len(keys) > 0:
|
|
||||||
chosen_key = random.choice(keys)
|
|
||||||
logger.info(
|
|
||||||
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
|
|
||||||
)
|
|
||||||
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return llm_response
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = [],
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts=[],
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
|
context_query = [*contexts, new_record]
|
||||||
|
if system_prompt:
|
||||||
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
|
for part in context_query:
|
||||||
|
if "_no_save" in part:
|
||||||
|
del part["_no_save"]
|
||||||
|
|
||||||
|
# tool calls result
|
||||||
|
if tool_calls_result:
|
||||||
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
|
||||||
|
model_config = self.provider_config.get("model_config", {})
|
||||||
|
model_config["model"] = self.get_model()
|
||||||
|
|
||||||
|
payloads = {"messages": context_query, **model_config}
|
||||||
|
|
||||||
|
retry = 10
|
||||||
|
keys = self.api_keys.copy()
|
||||||
|
|
||||||
|
for _ in range(retry):
|
||||||
|
try:
|
||||||
|
async for response in self._query_stream(payloads, func_tool):
|
||||||
|
yield response
|
||||||
|
break
|
||||||
|
except APIError as e:
|
||||||
|
if await self._handle_api_error(e, keys):
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
async def get_models(self):
|
||||||
|
try:
|
||||||
|
models = await self.client.models.list()
|
||||||
|
return [
|
||||||
|
m.name.replace("models/", "")
|
||||||
|
for m in models
|
||||||
|
if "generateContent" in m.supported_actions
|
||||||
|
]
|
||||||
|
except APIError as e:
|
||||||
|
raise Exception(f"获取模型列表失败: {e.message}")
|
||||||
|
|
||||||
def get_current_key(self) -> str:
|
def get_current_key(self) -> str:
|
||||||
return self.client.api_key
|
return self.chosen_api_key
|
||||||
|
|
||||||
def get_keys(self) -> List[str]:
|
def get_keys(self) -> List[str]:
|
||||||
return self.api_keys
|
return self.api_keys
|
||||||
|
|
||||||
def set_key(self, key):
|
def set_key(self, key):
|
||||||
self.client.api_key = key
|
self.chosen_api_key = key
|
||||||
|
self._init_client()
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||||
"""
|
"""
|
||||||
组装上下文。
|
组装上下文。
|
||||||
"""
|
"""
|
||||||
if image_urls:
|
if image_urls:
|
||||||
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
|
user_content = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||||
|
}
|
||||||
for image_url in image_urls:
|
for image_url in image_urls:
|
||||||
if image_url.startswith("http"):
|
if image_url.startswith("http"):
|
||||||
image_path = await download_image_by_url(image_url)
|
image_path = await download_image_by_url(image_url)
|
||||||
@@ -308,5 +586,4 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
await self.client.client.close()
|
|
||||||
logger.info("Google GenAI 适配器已终止。")
|
logger.info("Google GenAI 适配器已终止。")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import uuid
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from ..provider import TTSProvider
|
from ..provider import TTSProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
from llmtuner.chat import ChatModel
|
from llmtuner.chat import ChatModel
|
||||||
from typing import List
|
from typing import List
|
||||||
from .. import Provider
|
from .. import Provider
|
||||||
from ..entites import LLMResponse
|
from ..entities import LLMResponse
|
||||||
from ..func_tool_manager import FuncCall
|
from ..func_tool_manager import FuncCall
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
@@ -95,6 +95,33 @@ class LLMTunerModelLoader(Provider):
|
|||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=...,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=...,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
# 调用 text_chat 模拟流式
|
||||||
|
llm_response = await self.text_chat(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
)
|
||||||
|
llm_response.is_chunk = True
|
||||||
|
yield llm_response
|
||||||
|
llm_response.is_chunk = False
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
async def get_current_key(self):
|
async def get_current_key(self):
|
||||||
return "none"
|
return "none"
|
||||||
|
|
||||||
|
|||||||
@@ -2,19 +2,26 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import inspect
|
import inspect
|
||||||
|
import random
|
||||||
|
import asyncio
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
|
|
||||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
|
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||||
|
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.api.provider import Provider, Personality
|
from astrbot.api.provider import Provider, Personality
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from typing import List
|
from typing import List, AsyncGenerator
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.provider.entites import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter(
|
@register_provider_adapter(
|
||||||
@@ -80,7 +87,11 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
|
|
||||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||||
if tools:
|
if tools:
|
||||||
tool_list = tools.get_func_desc_openai_style()
|
model = payloads.get("model", "").lower()
|
||||||
|
omit_empty_param_field = "gemini" in model
|
||||||
|
tool_list = tools.get_func_desc_openai_style(
|
||||||
|
omit_empty_parameter_field=omit_empty_param_field
|
||||||
|
)
|
||||||
if tool_list:
|
if tool_list:
|
||||||
payloads["tools"] = tool_list
|
payloads["tools"] = tool_list
|
||||||
|
|
||||||
@@ -105,30 +116,93 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
|
|
||||||
logger.debug(f"completion: {completion}")
|
logger.debug(f"completion: {completion}")
|
||||||
|
|
||||||
|
llm_response = await self.parse_openai_completion(completion, tools)
|
||||||
|
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
async def _query_stream(
|
||||||
|
self, payloads: dict, tools: FuncCall
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""流式查询API,逐步返回结果"""
|
||||||
|
if tools:
|
||||||
|
model = payloads.get("model", "").lower()
|
||||||
|
omit_empty_param_field = "gemini" in model
|
||||||
|
tool_list = tools.get_func_desc_openai_style(
|
||||||
|
omit_empty_parameter_field=omit_empty_param_field
|
||||||
|
)
|
||||||
|
if tool_list:
|
||||||
|
payloads["tools"] = tool_list
|
||||||
|
|
||||||
|
# 不在默认参数中的参数放在 extra_body 中
|
||||||
|
extra_body = {}
|
||||||
|
to_del = []
|
||||||
|
for key in payloads.keys():
|
||||||
|
if key not in self.default_params:
|
||||||
|
extra_body[key] = payloads[key]
|
||||||
|
to_del.append(key)
|
||||||
|
for key in to_del:
|
||||||
|
del payloads[key]
|
||||||
|
|
||||||
|
stream = await self.client.chat.completions.create(
|
||||||
|
**payloads, stream=True, extra_body=extra_body
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||||
|
|
||||||
|
state = ChatCompletionStreamState()
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
try:
|
||||||
|
state.handle_chunk(chunk)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Saving chunk state error: " + str(e))
|
||||||
|
if len(chunk.choices) == 0:
|
||||||
|
continue
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
# 处理文本内容
|
||||||
|
if delta.content:
|
||||||
|
completion_text = delta.content
|
||||||
|
llm_response.result_chain = MessageChain(
|
||||||
|
chain=[Comp.Plain(completion_text)]
|
||||||
|
)
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
|
final_completion = state.get_final_completion()
|
||||||
|
llm_response = await self.parse_openai_completion(final_completion, tools)
|
||||||
|
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
|
async def parse_openai_completion(
|
||||||
|
self, completion: ChatCompletion, tools: FuncCall
|
||||||
|
):
|
||||||
|
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||||
|
llm_response = LLMResponse("assistant")
|
||||||
|
|
||||||
if len(completion.choices) == 0:
|
if len(completion.choices) == 0:
|
||||||
raise Exception("API 返回的 completion 为空。")
|
raise Exception("API 返回的 completion 为空。")
|
||||||
choice = completion.choices[0]
|
choice = completion.choices[0]
|
||||||
|
|
||||||
llm_response = LLMResponse("assistant")
|
|
||||||
|
|
||||||
if choice.message.content:
|
if choice.message.content:
|
||||||
# text completion
|
# text completion
|
||||||
completion_text = str(choice.message.content).strip()
|
completion_text = str(choice.message.content).strip()
|
||||||
llm_response.completion_text = completion_text
|
llm_response.result_chain = MessageChain().message(completion_text)
|
||||||
|
|
||||||
if choice.message.tool_calls:
|
if choice.message.tool_calls:
|
||||||
# tools call (function calling)
|
# tools call (function calling)
|
||||||
args_ls = []
|
args_ls = []
|
||||||
func_name_ls = []
|
func_name_ls = []
|
||||||
|
tool_call_ids = []
|
||||||
for tool_call in choice.message.tool_calls:
|
for tool_call in choice.message.tool_calls:
|
||||||
for tool in tools.func_list:
|
for tool in tools.func_list:
|
||||||
if tool.name == tool_call.function.name:
|
if tool.name == tool_call.function.name:
|
||||||
args = json.loads(tool_call.function.arguments)
|
args = json.loads(tool_call.function.arguments)
|
||||||
args_ls.append(args)
|
args_ls.append(args)
|
||||||
func_name_ls.append(tool_call.function.name)
|
func_name_ls.append(tool_call.function.name)
|
||||||
|
tool_call_ids.append(tool_call.id)
|
||||||
llm_response.role = "tool"
|
llm_response.role = "tool"
|
||||||
llm_response.tools_call_args = args_ls
|
llm_response.tools_call_args = args_ls
|
||||||
llm_response.tools_call_name = func_name_ls
|
llm_response.tools_call_name = func_name_ls
|
||||||
|
llm_response.tools_call_ids = tool_call_ids
|
||||||
|
|
||||||
if choice.finish_reason == "content_filter":
|
if choice.finish_reason == "content_filter":
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@@ -143,7 +217,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
async def text_chat(
|
async def _prepare_chat_payload(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
@@ -151,8 +225,10 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
func_tool: FuncCall = None,
|
func_tool: FuncCall = None,
|
||||||
contexts=[],
|
contexts=[],
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> tuple:
|
||||||
|
"""准备聊天所需的有效载荷和上下文"""
|
||||||
new_record = await self.assemble_context(prompt, image_urls)
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
context_query = [*contexts, new_record]
|
context_query = [*contexts, new_record]
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -162,84 +238,235 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
if "_no_save" in part:
|
if "_no_save" in part:
|
||||||
del part["_no_save"]
|
del part["_no_save"]
|
||||||
|
|
||||||
|
# tool calls result
|
||||||
|
if tool_calls_result:
|
||||||
|
context_query.extend(tool_calls_result.to_openai_messages())
|
||||||
|
|
||||||
model_config = self.provider_config.get("model_config", {})
|
model_config = self.provider_config.get("model_config", {})
|
||||||
model_config["model"] = self.get_model()
|
model_config["model"] = self.get_model()
|
||||||
|
|
||||||
payloads = {"messages": context_query, **model_config}
|
payloads = {"messages": context_query, **model_config}
|
||||||
llm_response = None
|
|
||||||
try:
|
return payloads, context_query, func_tool
|
||||||
llm_response = await self._query(payloads, func_tool)
|
|
||||||
except UnprocessableEntityError as e:
|
async def _handle_api_error(
|
||||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
self,
|
||||||
|
e: Exception,
|
||||||
|
payloads: dict,
|
||||||
|
context_query: list,
|
||||||
|
func_tool: FuncCall,
|
||||||
|
chosen_key: str,
|
||||||
|
available_api_keys: List[str],
|
||||||
|
retry_cnt: int,
|
||||||
|
max_retries: int,
|
||||||
|
) -> tuple:
|
||||||
|
"""处理API错误并尝试恢复"""
|
||||||
|
if "429" in str(e):
|
||||||
|
logger.warning(
|
||||||
|
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
|
||||||
|
)
|
||||||
|
# 最后一次不等待
|
||||||
|
if retry_cnt < max_retries - 1:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
available_api_keys.remove(chosen_key)
|
||||||
|
if len(available_api_keys) > 0:
|
||||||
|
chosen_key = random.choice(available_api_keys)
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
chosen_key,
|
||||||
|
available_api_keys,
|
||||||
|
payloads,
|
||||||
|
context_query,
|
||||||
|
func_tool,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
elif "maximum context length" in str(e):
|
||||||
|
logger.warning(
|
||||||
|
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||||
|
)
|
||||||
|
await self.pop_record(context_query)
|
||||||
|
payloads["messages"] = context_query
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
chosen_key,
|
||||||
|
available_api_keys,
|
||||||
|
payloads,
|
||||||
|
context_query,
|
||||||
|
func_tool,
|
||||||
|
)
|
||||||
|
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||||
# 尝试删除所有 image
|
# 尝试删除所有 image
|
||||||
new_contexts = await self._remove_image_from_context(context_query)
|
new_contexts = await self._remove_image_from_context(context_query)
|
||||||
payloads["messages"] = new_contexts
|
payloads["messages"] = new_contexts
|
||||||
context_query = new_contexts
|
context_query = new_contexts
|
||||||
llm_response = await self._query(payloads, func_tool)
|
return (
|
||||||
except Exception as e:
|
False,
|
||||||
if "maximum context length" in str(e):
|
chosen_key,
|
||||||
# 重试 10 次
|
available_api_keys,
|
||||||
retry_cnt = 20
|
payloads,
|
||||||
while retry_cnt > 0:
|
context_query,
|
||||||
logger.warning(
|
func_tool,
|
||||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
)
|
||||||
|
elif (
|
||||||
|
"Function calling is not enabled" in str(e)
|
||||||
|
or ("tool" in str(e).lower() and "support" in str(e).lower())
|
||||||
|
or ("function" in str(e).lower() and "support" in str(e).lower())
|
||||||
|
):
|
||||||
|
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||||
|
logger.info(
|
||||||
|
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||||
|
)
|
||||||
|
if "tools" in payloads:
|
||||||
|
del payloads["tools"]
|
||||||
|
return False, chosen_key, available_api_keys, payloads, context_query, None
|
||||||
|
else:
|
||||||
|
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||||
|
|
||||||
|
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||||
|
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||||
|
|
||||||
|
if "Connection error." in str(e):
|
||||||
|
proxy = os.environ.get("http_proxy", None)
|
||||||
|
if proxy:
|
||||||
|
logger.error(
|
||||||
|
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
await self.pop_record(context_query)
|
raise e
|
||||||
llm_response = await self._query(payloads, func_tool)
|
|
||||||
break
|
async def text_chat(
|
||||||
except Exception as e:
|
self,
|
||||||
if "maximum context length" in str(e):
|
prompt: str,
|
||||||
retry_cnt -= 1
|
session_id: str = None,
|
||||||
else:
|
image_urls: List[str] = [],
|
||||||
raise e
|
func_tool: FuncCall = None,
|
||||||
if retry_cnt == 0:
|
contexts=[],
|
||||||
llm_response = LLMResponse(
|
system_prompt=None,
|
||||||
"err", "err: 请尝试 /reset 清除会话记录。"
|
tool_calls_result=None,
|
||||||
)
|
**kwargs,
|
||||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
) -> LLMResponse:
|
||||||
|
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||||
|
prompt,
|
||||||
|
session_id,
|
||||||
|
image_urls,
|
||||||
|
func_tool,
|
||||||
|
contexts,
|
||||||
|
system_prompt,
|
||||||
|
tool_calls_result,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_response = None
|
||||||
|
max_retries = 10
|
||||||
|
available_api_keys = self.api_keys.copy()
|
||||||
|
chosen_key = random.choice(available_api_keys)
|
||||||
|
|
||||||
|
e = None
|
||||||
|
retry_cnt = 0
|
||||||
|
for retry_cnt in range(max_retries):
|
||||||
|
try:
|
||||||
|
self.client.api_key = chosen_key
|
||||||
|
llm_response = await self._query(payloads, func_tool)
|
||||||
|
break
|
||||||
|
except UnprocessableEntityError as e:
|
||||||
|
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||||
# 尝试删除所有 image
|
# 尝试删除所有 image
|
||||||
new_contexts = await self._remove_image_from_context(context_query)
|
new_contexts = await self._remove_image_from_context(context_query)
|
||||||
payloads["messages"] = new_contexts
|
payloads["messages"] = new_contexts
|
||||||
llm_response = await self._query(payloads, func_tool)
|
context_query = new_contexts
|
||||||
|
except Exception as e:
|
||||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
(
|
||||||
elif (
|
success,
|
||||||
"does not support Function Calling" in str(e)
|
chosen_key,
|
||||||
or "does not support tools" in str(e)
|
available_api_keys,
|
||||||
or "Function call is not supported" in str(e)
|
payloads,
|
||||||
or "Function calling is not enabled" in str(e)
|
context_query,
|
||||||
or "Tool calling is not supported" in str(e)
|
func_tool,
|
||||||
or "No endpoints found that support tool use" in str(e)
|
) = await self._handle_api_error(
|
||||||
or "model does not support function calling" in str(e)
|
e,
|
||||||
or ("tool" in str(e) and "support" in str(e).lower())
|
payloads,
|
||||||
or ("function" in str(e) and "support" in str(e).lower())
|
context_query,
|
||||||
):
|
func_tool,
|
||||||
logger.info(
|
chosen_key,
|
||||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
available_api_keys,
|
||||||
|
retry_cnt,
|
||||||
|
max_retries,
|
||||||
)
|
)
|
||||||
if "tools" in payloads:
|
if success:
|
||||||
del payloads["tools"]
|
break
|
||||||
llm_response = await self._query(payloads, None)
|
|
||||||
else:
|
|
||||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
|
||||||
|
|
||||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
|
||||||
logger.error(
|
|
||||||
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "Connection error." in str(e):
|
|
||||||
proxy = os.environ.get("http_proxy", None)
|
|
||||||
if proxy:
|
|
||||||
logger.error(
|
|
||||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
|
||||||
)
|
|
||||||
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
if retry_cnt == max_retries - 1:
|
||||||
|
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||||
|
raise e
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = [],
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts=[],
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""流式对话,与服务商交互并逐步返回结果"""
|
||||||
|
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||||
|
prompt,
|
||||||
|
session_id,
|
||||||
|
image_urls,
|
||||||
|
func_tool,
|
||||||
|
contexts,
|
||||||
|
system_prompt,
|
||||||
|
tool_calls_result,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_retries = 10
|
||||||
|
available_api_keys = self.api_keys.copy()
|
||||||
|
chosen_key = random.choice(available_api_keys)
|
||||||
|
|
||||||
|
e = None
|
||||||
|
retry_cnt = 0
|
||||||
|
for retry_cnt in range(max_retries):
|
||||||
|
try:
|
||||||
|
self.client.api_key = chosen_key
|
||||||
|
async for response in self._query_stream(payloads, func_tool):
|
||||||
|
yield response
|
||||||
|
break
|
||||||
|
except UnprocessableEntityError as e:
|
||||||
|
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||||
|
# 尝试删除所有 image
|
||||||
|
new_contexts = await self._remove_image_from_context(context_query)
|
||||||
|
payloads["messages"] = new_contexts
|
||||||
|
context_query = new_contexts
|
||||||
|
except Exception as e:
|
||||||
|
(
|
||||||
|
success,
|
||||||
|
chosen_key,
|
||||||
|
available_api_keys,
|
||||||
|
payloads,
|
||||||
|
context_query,
|
||||||
|
func_tool,
|
||||||
|
) = await self._handle_api_error(
|
||||||
|
e,
|
||||||
|
payloads,
|
||||||
|
context_query,
|
||||||
|
func_tool,
|
||||||
|
chosen_key,
|
||||||
|
available_api_keys,
|
||||||
|
retry_cnt,
|
||||||
|
max_retries,
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
break
|
||||||
|
|
||||||
|
if retry_cnt == max_retries - 1:
|
||||||
|
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||||
|
raise e
|
||||||
|
|
||||||
async def _remove_image_from_context(self, contexts: List):
|
async def _remove_image_from_context(self, contexts: List):
|
||||||
"""
|
"""
|
||||||
从上下文中删除所有带有 image 的记录
|
从上下文中删除所有带有 image 的记录
|
||||||
@@ -275,12 +502,10 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
def set_key(self, key):
|
def set_key(self, key):
|
||||||
self.client.api_key = key
|
self.client.api_key = key
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
|
||||||
"""
|
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||||
组装上下文。
|
|
||||||
"""
|
|
||||||
if image_urls:
|
if image_urls:
|
||||||
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
|
user_content = {"role": "user", "content": [{"type": "text", "text": text if text else "[图片]"}]}
|
||||||
for image_url in image_urls:
|
for image_url in image_urls:
|
||||||
if image_url.startswith("http"):
|
if image_url.startswith("http"):
|
||||||
image_path = await download_image_by_url(image_url)
|
image_path = await download_image_by_url(image_url)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from openai import AsyncOpenAI, NOT_GIVEN
|
from openai import AsyncOpenAI, NOT_GIVEN
|
||||||
from ..provider import TTSProvider
|
from ..provider import TTSProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
|
||||||
|
|
||||||
@@ -18,10 +18,14 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
|||||||
self.chosen_api_key = provider_config.get("api_key", "")
|
self.chosen_api_key = provider_config.get("api_key", "")
|
||||||
self.voice = provider_config.get("openai-tts-voice", "alloy")
|
self.voice = provider_config.get("openai-tts-voice", "alloy")
|
||||||
|
|
||||||
|
timeout = provider_config.get("timeout", NOT_GIVEN)
|
||||||
|
if isinstance(timeout, str):
|
||||||
|
timeout = int(timeout)
|
||||||
|
|
||||||
self.client = AsyncOpenAI(
|
self.client = AsyncOpenAI(
|
||||||
api_key=self.chosen_api_key,
|
api_key=self.chosen_api_key,
|
||||||
base_url=provider_config.get("api_base", None),
|
base_url=provider_config.get("api_base", None),
|
||||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.set_model(provider_config.get("model", None))
|
self.set_model(provider_config.get("model", None))
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import re
|
|||||||
from funasr_onnx import SenseVoiceSmall
|
from funasr_onnx import SenseVoiceSmall
|
||||||
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
||||||
from ..provider import STTProvider
|
from ..provider import STTProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from astrbot.core.utils.io import download_file
|
from astrbot.core.utils.io import download_file
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
@@ -48,14 +48,6 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
return os.path.join("data", "temp", f"{timestamp}")
|
return os.path.join("data", "temp", f"{timestamp}")
|
||||||
|
|
||||||
async def _convert_audio(self, path: str) -> str:
|
|
||||||
from pyffmpeg import FFmpeg
|
|
||||||
|
|
||||||
filename = await self.get_timestamped_path() + ".mp3"
|
|
||||||
ff = FFmpeg()
|
|
||||||
output_path = ff.convert(path, os.path.join('data","temp', filename))
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
async def _is_silk_file(self, file_path):
|
async def _is_silk_file(self, file_path):
|
||||||
silk_header = b"SILK"
|
silk_header = b"SILK"
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import uuid
|
|||||||
import os
|
import os
|
||||||
from openai import AsyncOpenAI, NOT_GIVEN
|
from openai import AsyncOpenAI, NOT_GIVEN
|
||||||
from ..provider import STTProvider
|
from ..provider import STTProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from astrbot.core.utils.io import download_file
|
from astrbot.core.utils.io import download_file
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
@@ -31,14 +31,6 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
|||||||
|
|
||||||
self.set_model(provider_config.get("model", None))
|
self.set_model(provider_config.get("model", None))
|
||||||
|
|
||||||
async def _convert_audio(self, path: str) -> str:
|
|
||||||
from pyffmpeg import FFmpeg
|
|
||||||
|
|
||||||
filename = str(uuid.uuid4()) + ".mp3"
|
|
||||||
ff = FFmpeg()
|
|
||||||
output_path = ff.convert(path, os.path.join("data/temp", filename))
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
async def _is_silk_file(self, file_path):
|
async def _is_silk_file(self, file_path):
|
||||||
silk_header = b"SILK"
|
silk_header = b"SILK"
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import whisper
|
import whisper
|
||||||
from ..provider import STTProvider
|
from ..provider import STTProvider
|
||||||
from ..entites import ProviderType
|
from ..entities import ProviderType
|
||||||
from astrbot.core.utils.io import download_file
|
from astrbot.core.utils.io import download_file
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
@@ -33,14 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
)
|
)
|
||||||
logger.info("Whisper 模型加载完成。")
|
logger.info("Whisper 模型加载完成。")
|
||||||
|
|
||||||
async def _convert_audio(self, path: str) -> str:
|
|
||||||
from pyffmpeg import FFmpeg
|
|
||||||
|
|
||||||
filename = str(uuid.uuid4()) + ".mp3"
|
|
||||||
ff = FFmpeg()
|
|
||||||
output_path = ff.convert(path, os.path.join("data/temp", filename))
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
async def _is_silk_file(self, file_path):
|
async def _is_silk_file(self, file_path):
|
||||||
silk_header = b"SILK"
|
silk_header = b"SILK"
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from astrbot import logger
|
|||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from typing import List
|
from typing import List
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.provider.entites import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from .openai_source import ProviderOpenAIOfficial
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ from .context import Context
|
|||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||||
from astrbot.core import html_renderer
|
from astrbot.core import html_renderer
|
||||||
|
from astrbot.core.star.star_tools import StarTools
|
||||||
|
|
||||||
|
|
||||||
class Star(CommandParserMixin):
|
class Star(CommandParserMixin):
|
||||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||||
|
|
||||||
def __init__(self, context: Context):
|
def __init__(self, context: Context):
|
||||||
|
StarTools.initialize(context)
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||||
@@ -27,4 +29,4 @@ class Star(CommandParserMixin):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"]
|
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||||
|
|||||||
@@ -183,11 +183,15 @@ class Context:
|
|||||||
获取指定类型的平台适配器。
|
获取指定类型的平台适配器。
|
||||||
"""
|
"""
|
||||||
for platform in self.platform_manager.platform_insts:
|
for platform in self.platform_manager.platform_insts:
|
||||||
|
name = platform.meta().name
|
||||||
if isinstance(platform_type, str):
|
if isinstance(platform_type, str):
|
||||||
if platform.meta().name == platform_type:
|
if name == platform_type:
|
||||||
return platform
|
return platform
|
||||||
else:
|
else:
|
||||||
if platform.meta().name == ADAPTER_NAME_2_TYPE[platform_type]:
|
if (
|
||||||
|
name in ADAPTER_NAME_2_TYPE
|
||||||
|
and ADAPTER_NAME_2_TYPE[name] & platform_type
|
||||||
|
):
|
||||||
return platform
|
return platform
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
|
|||||||
0
astrbot/core/star/filter/command.py
Normal file → Executable file
0
astrbot/core/star/filter/command.py
Normal file → Executable file
0
astrbot/core/star/filter/command_group.py
Normal file → Executable file
0
astrbot/core/star/filter/command_group.py
Normal file → Executable file
@@ -15,7 +15,6 @@ from ..filter.regex import RegexFilter
|
|||||||
from typing import Awaitable
|
from typing import Awaitable
|
||||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
from astrbot.core import logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||||
@@ -359,9 +358,9 @@ def register_llm_tool(name: str = None):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||||
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
|
llm_tools.add_func(
|
||||||
|
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
)
|
||||||
return awaitable
|
return awaitable
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ star_map: Dict[str, StarMetadata] = {}
|
|||||||
class StarMetadata:
|
class StarMetadata:
|
||||||
"""
|
"""
|
||||||
插件的元数据。
|
插件的元数据。
|
||||||
|
|
||||||
|
当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@@ -45,5 +47,29 @@ class StarMetadata:
|
|||||||
star_handler_full_names: List[str] = field(default_factory=list)
|
star_handler_full_names: List[str] = field(default_factory=list)
|
||||||
"""注册的 Handler 的全名列表"""
|
"""注册的 Handler 的全名列表"""
|
||||||
|
|
||||||
|
supported_platforms: Dict[str, bool] = field(default_factory=dict)
|
||||||
|
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||||
|
|
||||||
|
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||||
|
"""更新插件支持的平台列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项
|
||||||
|
"""
|
||||||
|
if not plugin_enable_config:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 清空之前的配置
|
||||||
|
self.supported_platforms.clear()
|
||||||
|
|
||||||
|
# 遍历所有平台配置
|
||||||
|
for platform_id, plugins in plugin_enable_config.items():
|
||||||
|
# 检查该插件在当前平台的配置
|
||||||
|
if self.name in plugins:
|
||||||
|
self.supported_platforms[platform_id] = plugins[self.name]
|
||||||
|
else:
|
||||||
|
# 如果没有明确配置,默认为启用
|
||||||
|
self.supported_platforms[platform_id] = True
|
||||||
|
|||||||
@@ -30,21 +30,36 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
print(handler.handler_full_name)
|
print(handler.handler_full_name)
|
||||||
|
|
||||||
def get_handlers_by_event_type(
|
def get_handlers_by_event_type(
|
||||||
self, event_type: EventType, only_activated=True
|
self, event_type: EventType, only_activated=True, platform_id=None
|
||||||
) -> List[StarHandlerMetadata]:
|
) -> List[StarHandlerMetadata]:
|
||||||
"""通过事件类型获取 Handler"""
|
"""通过事件类型获取 Handler
|
||||||
handlers = [
|
|
||||||
handler
|
Args:
|
||||||
for _, handler in self._handlers
|
event_type: 事件类型
|
||||||
if handler.event_type == event_type
|
only_activated: 是否只返回已激活的插件的处理器
|
||||||
and (
|
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||||
not only_activated
|
|
||||||
or (
|
Returns:
|
||||||
star_map[handler.handler_module_path]
|
List[StarHandlerMetadata]: 处理器列表
|
||||||
and star_map[handler.handler_module_path].activated
|
"""
|
||||||
)
|
handlers = []
|
||||||
)
|
for _, handler in self._handlers:
|
||||||
]
|
if handler.event_type != event_type:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 只激活的插件处理器
|
||||||
|
if only_activated:
|
||||||
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
|
if not (plugin and plugin.activated):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 平台兼容性过滤
|
||||||
|
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||||
|
if not handler.is_enabled_for_platform(platform_id):
|
||||||
|
continue
|
||||||
|
|
||||||
|
handlers.append(handler)
|
||||||
|
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||||
@@ -139,3 +154,32 @@ class StarHandlerMetadata:
|
|||||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||||
"priority", 0
|
"priority", 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_enabled_for_platform(self, platform_id: str) -> bool:
|
||||||
|
"""检查插件是否在指定平台启用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否启用,True表示启用,False表示禁用
|
||||||
|
"""
|
||||||
|
plugin = star_map.get(self.handler_module_path)
|
||||||
|
|
||||||
|
# 如果插件元数据不存在,默认允许执行
|
||||||
|
if not plugin or not plugin.name:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 先检查插件是否被激活
|
||||||
|
if not plugin.activated:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
|
||||||
|
if (
|
||||||
|
hasattr(plugin, "supported_platforms")
|
||||||
|
and platform_id in plugin.supported_platforms
|
||||||
|
):
|
||||||
|
return plugin.supported_platforms[platform_id]
|
||||||
|
|
||||||
|
# 如果没有缓存数据,默认允许执行
|
||||||
|
return True
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
插件的重载、启停、安装、卸载等操作。
|
||||||
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
@@ -24,7 +28,7 @@ from .filter.permission import PermissionTypeFilter, PermissionType
|
|||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
def __init__(self, context: Context, config: AstrBotConfig):
|
def __init__(self, context: Context, config: AstrBotConfig):
|
||||||
self.updator = PluginUpdator(config["plugin_repo_mirror"])
|
self.updator = PluginUpdator()
|
||||||
|
|
||||||
self.context = context
|
self.context = context
|
||||||
self.context._star_manager = self
|
self.context._star_manager = self
|
||||||
@@ -75,7 +79,7 @@ class PluginManager:
|
|||||||
elif os.path.exists(os.path.join(path, d, d + ".py")):
|
elif os.path.exists(os.path.join(path, d, d + ".py")):
|
||||||
module_str = d
|
module_str = d
|
||||||
else:
|
else:
|
||||||
print(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。")
|
logger.info(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。")
|
||||||
continue
|
continue
|
||||||
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
|
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
|
||||||
os.path.join(path, d, d + ".py")
|
os.path.join(path, d, d + ".py")
|
||||||
@@ -162,9 +166,71 @@ class PluginManager:
|
|||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
async def reload(self, specified_plugin_name=None):
|
def _get_plugin_related_modules(
|
||||||
"""扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件"""
|
self, plugin_root_dir: str, is_reserved: bool = False
|
||||||
|
) -> list[str]:
|
||||||
|
"""获取与指定插件相关的所有已加载模块名
|
||||||
|
|
||||||
|
根据插件根目录名和是否为保留插件,从 sys.modules 中筛选出相关的模块名
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_root_dir: 插件根目录名
|
||||||
|
is_reserved: 是否是保留插件,影响模块路径前缀
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: 与该插件相关的模块名列表
|
||||||
|
"""
|
||||||
|
prefix = "packages." if is_reserved else "data.plugins."
|
||||||
|
return [
|
||||||
|
key
|
||||||
|
for key in list(sys.modules.keys())
|
||||||
|
if key.startswith(f"{prefix}{plugin_root_dir}")
|
||||||
|
]
|
||||||
|
|
||||||
|
def _purge_modules(
|
||||||
|
self,
|
||||||
|
module_patterns: list[str] = None,
|
||||||
|
root_dir_name: str = None,
|
||||||
|
is_reserved: bool = False,
|
||||||
|
):
|
||||||
|
"""从 sys.modules 中移除指定的模块
|
||||||
|
|
||||||
|
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
||||||
|
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||||
|
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||||
|
"""
|
||||||
|
if module_patterns:
|
||||||
|
for pattern in module_patterns:
|
||||||
|
for key in list(sys.modules.keys()):
|
||||||
|
if key.startswith(pattern):
|
||||||
|
del sys.modules[key]
|
||||||
|
logger.debug(f"删除模块 {key}")
|
||||||
|
|
||||||
|
if root_dir_name:
|
||||||
|
for module_name in self._get_plugin_related_modules(
|
||||||
|
root_dir_name, is_reserved
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
logger.debug(f"删除模块 {module_name}")
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"模块 {module_name} 未载入")
|
||||||
|
|
||||||
|
async def reload(self, specified_plugin_name=None):
|
||||||
|
"""重新加载插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
specified_plugin_name (str, optional): 要重载的特定插件名称。
|
||||||
|
如果为 None,则重载所有插件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: 返回 load() 方法的结果,包含 (success, error_message)
|
||||||
|
- success (bool): 重载是否成功
|
||||||
|
- error_message (str|None): 错误信息,成功时为 None
|
||||||
|
"""
|
||||||
specified_module_path = None
|
specified_module_path = None
|
||||||
if specified_plugin_name:
|
if specified_plugin_name:
|
||||||
for smd in star_registry:
|
for smd in star_registry:
|
||||||
@@ -184,12 +250,11 @@ class PluginManager:
|
|||||||
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self._unbind_plugin(smd.name, smd.module_path)
|
||||||
|
|
||||||
star_handlers_registry.clear()
|
star_handlers_registry.clear()
|
||||||
star_map.clear()
|
star_map.clear()
|
||||||
star_registry.clear()
|
star_registry.clear()
|
||||||
for key in list(sys.modules.keys()):
|
|
||||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
|
||||||
del sys.modules[key]
|
|
||||||
else:
|
else:
|
||||||
# 只重载指定插件
|
# 只重载指定插件
|
||||||
smd = star_map.get(specified_module_path)
|
smd = star_map.get(specified_module_path)
|
||||||
@@ -203,10 +268,50 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self._unbind_plugin(smd.name, specified_module_path)
|
await self._unbind_plugin(smd.name, specified_module_path)
|
||||||
try:
|
|
||||||
del sys.modules[specified_module_path]
|
result = await self.load(specified_module_path)
|
||||||
except KeyError:
|
|
||||||
logger.warning(f"模块 {specified_module_path} 未载入")
|
# 更新所有插件的平台兼容性
|
||||||
|
await self.update_all_platform_compatibility()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def update_all_platform_compatibility(self):
|
||||||
|
"""更新所有插件的平台兼容性设置"""
|
||||||
|
# 获取最新的平台插件启用配置
|
||||||
|
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||||
|
"plugin_enable", {}
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 遍历所有插件,更新平台兼容性
|
||||||
|
for plugin in self.context.get_all_stars():
|
||||||
|
plugin.update_platform_compatibility(plugin_enable_config)
|
||||||
|
logger.debug(
|
||||||
|
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def load(self, specified_module_path=None, specified_dir_name=None):
|
||||||
|
"""载入插件。
|
||||||
|
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
specified_module_path (str, optional): 指定要加载的插件模块路径。例如: "data.plugins.my_plugin.main"
|
||||||
|
specified_dir_name (str, optional): 指定要加载的插件目录名。例如: "my_plugin"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (success, error_message)
|
||||||
|
- success (bool): 是否全部加载成功
|
||||||
|
- error_message (str|None): 错误信息,成功时为 None
|
||||||
|
"""
|
||||||
|
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||||
|
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||||
|
|
||||||
|
alter_cmd = sp.get("alter_cmd", {})
|
||||||
|
|
||||||
plugin_modules = self._get_plugin_modules()
|
plugin_modules = self._get_plugin_modules()
|
||||||
if plugin_modules is None:
|
if plugin_modules is None:
|
||||||
@@ -214,11 +319,6 @@ class PluginManager:
|
|||||||
|
|
||||||
fail_rec = ""
|
fail_rec = ""
|
||||||
|
|
||||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
|
||||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
|
||||||
|
|
||||||
alter_cmd = sp.get("alter_cmd", {})
|
|
||||||
|
|
||||||
# 导入插件模块,并尝试实例化插件类
|
# 导入插件模块,并尝试实例化插件类
|
||||||
for plugin_module in plugin_modules:
|
for plugin_module in plugin_modules:
|
||||||
try:
|
try:
|
||||||
@@ -232,8 +332,11 @@ class PluginManager:
|
|||||||
path = "data.plugins." if not reserved else "packages."
|
path = "data.plugins." if not reserved else "packages."
|
||||||
path += root_dir_name + "." + module_str
|
path += root_dir_name + "." + module_str
|
||||||
|
|
||||||
|
# 检查是否需要载入指定的插件
|
||||||
if specified_module_path and path != specified_module_path:
|
if specified_module_path and path != specified_module_path:
|
||||||
continue
|
continue
|
||||||
|
if specified_dir_name and root_dir_name != specified_dir_name:
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info(f"正在载入插件 {root_dir_name} ...")
|
logger.info(f"正在载入插件 {root_dir_name} ...")
|
||||||
|
|
||||||
@@ -287,23 +390,35 @@ class PluginManager:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if plugin_config:
|
if path not in inactivated_plugins:
|
||||||
metadata.config = plugin_config
|
# 只有没有禁用插件时才实例化插件类
|
||||||
try:
|
if plugin_config:
|
||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.config = plugin_config
|
||||||
context=self.context, config=plugin_config
|
try:
|
||||||
)
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
except TypeError as _:
|
context=self.context, config=plugin_config
|
||||||
|
)
|
||||||
|
except TypeError as _:
|
||||||
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
|
context=self.context
|
||||||
|
)
|
||||||
|
else:
|
||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
context=self.context
|
context=self.context
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||||
|
|
||||||
metadata.module = module
|
metadata.module = module
|
||||||
metadata.root_dir_name = root_dir_name
|
metadata.root_dir_name = root_dir_name
|
||||||
metadata.reserved = reserved
|
metadata.reserved = reserved
|
||||||
|
|
||||||
|
# 更新插件的平台兼容性
|
||||||
|
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||||
|
"plugin_enable", {}
|
||||||
|
)
|
||||||
|
metadata.update_platform_compatibility(plugin_enable_config)
|
||||||
|
|
||||||
# 绑定 handler
|
# 绑定 handler
|
||||||
related_handlers = (
|
related_handlers = (
|
||||||
star_handlers_registry.get_handlers_by_module_name(
|
star_handlers_registry.get_handlers_by_module_name(
|
||||||
@@ -316,7 +431,10 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
# 绑定 llm_tool handler
|
# 绑定 llm_tool handler
|
||||||
for func_tool in llm_tools.func_list:
|
for func_tool in llm_tools.func_list:
|
||||||
if func_tool.handler.__module__ == metadata.module_path:
|
if (
|
||||||
|
func_tool.handler
|
||||||
|
and func_tool.handler.__module__ == metadata.module_path
|
||||||
|
):
|
||||||
func_tool.handler_module_path = metadata.module_path
|
func_tool.handler_module_path = metadata.module_path
|
||||||
func_tool.handler = functools.partial(
|
func_tool.handler = functools.partial(
|
||||||
func_tool.handler, metadata.star_cls
|
func_tool.handler, metadata.star_cls
|
||||||
@@ -331,19 +449,23 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
classes = self._get_classes(module)
|
classes = self._get_classes(module)
|
||||||
|
|
||||||
if plugin_config:
|
if path not in inactivated_plugins:
|
||||||
try:
|
# 只有没有禁用插件时才实例化插件类
|
||||||
obj = getattr(module, classes[0])(
|
if plugin_config:
|
||||||
context=self.context, config=plugin_config
|
try:
|
||||||
) # 实例化插件类
|
obj = getattr(module, classes[0])(
|
||||||
except TypeError as _:
|
context=self.context, config=plugin_config
|
||||||
|
) # 实例化插件类
|
||||||
|
except TypeError as _:
|
||||||
|
obj = getattr(module, classes[0])(
|
||||||
|
context=self.context
|
||||||
|
) # 实例化插件类
|
||||||
|
else:
|
||||||
obj = getattr(module, classes[0])(
|
obj = getattr(module, classes[0])(
|
||||||
context=self.context
|
context=self.context
|
||||||
) # 实例化插件类
|
) # 实例化插件类
|
||||||
else:
|
else:
|
||||||
obj = getattr(module, classes[0])(
|
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||||
context=self.context
|
|
||||||
) # 实例化插件类
|
|
||||||
|
|
||||||
metadata = None
|
metadata = None
|
||||||
metadata = self._load_plugin_metadata(
|
metadata = self._load_plugin_metadata(
|
||||||
@@ -424,12 +546,62 @@ class PluginManager:
|
|||||||
return False, fail_rec
|
return False, fail_rec
|
||||||
|
|
||||||
async def install_plugin(self, repo_url: str, proxy=""):
|
async def install_plugin(self, repo_url: str, proxy=""):
|
||||||
|
"""从仓库 URL 安装插件
|
||||||
|
|
||||||
|
从指定的仓库 URL 下载并安装插件,然后加载该插件到系统中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_url (str): 要安装的插件仓库 URL
|
||||||
|
proxy (str, optional): 用于下载的代理服务器。默认为空字符串。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict | None: 安装成功时返回包含插件信息的字典:
|
||||||
|
- repo: 插件的仓库 URL
|
||||||
|
- readme: README.md 文件的内容(如果存在)
|
||||||
|
如果找不到插件元数据则返回 None。
|
||||||
|
"""
|
||||||
plugin_path = await self.updator.install(repo_url, proxy)
|
plugin_path = await self.updator.install(repo_url, proxy)
|
||||||
# reload the plugin
|
# reload the plugin
|
||||||
await self.reload()
|
dir_name = os.path.basename(plugin_path)
|
||||||
return plugin_path
|
await self.load(specified_dir_name=dir_name)
|
||||||
|
|
||||||
|
# Get the plugin metadata to return repo info
|
||||||
|
plugin = self.context.get_registered_star(dir_name)
|
||||||
|
if not plugin:
|
||||||
|
# Try to find by other name if directory name doesn't match plugin name
|
||||||
|
for star in self.context.get_all_stars():
|
||||||
|
if star.root_dir_name == dir_name:
|
||||||
|
plugin = star
|
||||||
|
break
|
||||||
|
|
||||||
|
# Extract README.md content if exists
|
||||||
|
readme_content = None
|
||||||
|
readme_path = os.path.join(plugin_path, "README.md")
|
||||||
|
if not os.path.exists(readme_path):
|
||||||
|
readme_path = os.path.join(plugin_path, "readme.md")
|
||||||
|
|
||||||
|
if os.path.exists(readme_path):
|
||||||
|
try:
|
||||||
|
with open(readme_path, "r", encoding="utf-8") as f:
|
||||||
|
readme_content = f.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||||
|
|
||||||
|
plugin_info = None
|
||||||
|
if plugin:
|
||||||
|
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||||
|
|
||||||
|
return plugin_info
|
||||||
|
|
||||||
async def uninstall_plugin(self, plugin_name: str):
|
async def uninstall_plugin(self, plugin_name: str):
|
||||||
|
"""卸载指定的插件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): 要卸载的插件名称
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||||
|
"""
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
raise Exception("插件不存在。")
|
raise Exception("插件不存在。")
|
||||||
@@ -450,32 +622,46 @@ class PluginManager:
|
|||||||
# 从 star_registry 和 star_map 中删除
|
# 从 star_registry 和 star_map 中删除
|
||||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||||
|
|
||||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
try:
|
||||||
|
remove_dir(os.path.join(ppath, root_dir_name))
|
||||||
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||||
|
"""解绑并移除一个插件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name: 要解绑的插件名称
|
||||||
|
plugin_module_path: 插件的完整模块路径
|
||||||
|
"""
|
||||||
|
plugin = None
|
||||||
del star_map[plugin_module_path]
|
del star_map[plugin_module_path]
|
||||||
for i, p in enumerate(star_registry):
|
for i, p in enumerate(star_registry):
|
||||||
if p.name == plugin_name:
|
if p.name == plugin_name:
|
||||||
|
plugin = p
|
||||||
del star_registry[i]
|
del star_registry[i]
|
||||||
break
|
break
|
||||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||||
plugin_module_path
|
plugin_module_path
|
||||||
):
|
):
|
||||||
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
|
logger.info(
|
||||||
|
f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})"
|
||||||
|
)
|
||||||
star_handlers_registry.remove(handler)
|
star_handlers_registry.remove(handler)
|
||||||
keys_to_delete = [
|
|
||||||
|
for k in [
|
||||||
k
|
k
|
||||||
for k, v in star_handlers_registry.star_handlers_map.items()
|
for k in star_handlers_registry.star_handlers_map
|
||||||
if k.startswith(plugin_module_path)
|
if k.startswith(plugin_module_path)
|
||||||
]
|
]:
|
||||||
for k in keys_to_delete:
|
|
||||||
v = star_handlers_registry.star_handlers_map[k]
|
|
||||||
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
|
|
||||||
del star_handlers_registry.star_handlers_map[k]
|
del star_handlers_registry.star_handlers_map[k]
|
||||||
|
|
||||||
|
self._purge_modules(
|
||||||
|
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||||
|
)
|
||||||
|
|
||||||
async def update_plugin(self, plugin_name: str, proxy=""):
|
async def update_plugin(self, plugin_name: str, proxy=""):
|
||||||
"""升级一个插件"""
|
"""升级一个插件"""
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
@@ -485,7 +671,7 @@ class PluginManager:
|
|||||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||||
|
|
||||||
await self.updator.update(plugin, proxy=proxy)
|
await self.updator.update(plugin, proxy=proxy)
|
||||||
await self.reload()
|
await self.reload(plugin_name)
|
||||||
|
|
||||||
async def turn_off_plugin(self, plugin_name: str):
|
async def turn_off_plugin(self, plugin_name: str):
|
||||||
"""
|
"""
|
||||||
@@ -524,11 +710,18 @@ class PluginManager:
|
|||||||
|
|
||||||
async def _terminate_plugin(self, star_metadata: StarMetadata):
|
async def _terminate_plugin(self, star_metadata: StarMetadata):
|
||||||
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
||||||
logging.info(f"正在终止插件 {star_metadata.name} ...")
|
logger.info(f"正在终止插件 {star_metadata.name} ...")
|
||||||
|
|
||||||
|
if not star_metadata.activated:
|
||||||
|
# 说明之前已经被禁用了
|
||||||
|
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
|
||||||
|
return
|
||||||
|
|
||||||
if hasattr(star_metadata.star_cls, "__del__"):
|
if hasattr(star_metadata.star_cls, "__del__"):
|
||||||
asyncio.get_event_loop().run_in_executor(star_metadata.star_cls.__del__)
|
asyncio.get_event_loop().run_in_executor(
|
||||||
else:
|
None, star_metadata.star_cls.__del__
|
||||||
|
)
|
||||||
|
elif hasattr(star_metadata.star_cls, "terminate"):
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
async def turn_on_plugin(self, plugin_name: str):
|
async def turn_on_plugin(self, plugin_name: str):
|
||||||
@@ -541,12 +734,17 @@ class PluginManager:
|
|||||||
|
|
||||||
# 启用插件启用的 llm_tool
|
# 启用插件启用的 llm_tool
|
||||||
for func_tool in llm_tools.func_list:
|
for func_tool in llm_tools.func_list:
|
||||||
if func_tool.handler_module_path == plugin.module_path:
|
if (
|
||||||
|
func_tool.handler_module_path == plugin.module_path
|
||||||
|
and func_tool.name in inactivated_llm_tools
|
||||||
|
):
|
||||||
inactivated_llm_tools.remove(func_tool.name)
|
inactivated_llm_tools.remove(func_tool.name)
|
||||||
func_tool.active = True
|
func_tool.active = True
|
||||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||||
|
|
||||||
plugin.activated = True
|
await self.reload(plugin_name)
|
||||||
|
|
||||||
|
# plugin.activated = True
|
||||||
|
|
||||||
async def install_plugin_from_file(self, zip_file_path: str):
|
async def install_plugin_from_file(self, zip_file_path: str):
|
||||||
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
||||||
@@ -559,4 +757,33 @@ class PluginManager:
|
|||||||
os.remove(zip_file_path)
|
os.remove(zip_file_path)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||||
await self.reload()
|
# await self.reload()
|
||||||
|
await self.load(specified_dir_name=dir_name)
|
||||||
|
|
||||||
|
# Get the plugin metadata to return repo info
|
||||||
|
plugin = self.context.get_registered_star(dir_name)
|
||||||
|
if not plugin:
|
||||||
|
# Try to find by other name if directory name doesn't match plugin name
|
||||||
|
for star in self.context.get_all_stars():
|
||||||
|
if star.root_dir_name == dir_name:
|
||||||
|
plugin = star
|
||||||
|
break
|
||||||
|
|
||||||
|
# Extract README.md content if exists
|
||||||
|
readme_content = None
|
||||||
|
readme_path = os.path.join(desti_dir, "README.md")
|
||||||
|
if not os.path.exists(readme_path):
|
||||||
|
readme_path = os.path.join(desti_dir, "readme.md")
|
||||||
|
|
||||||
|
if os.path.exists(readme_path):
|
||||||
|
try:
|
||||||
|
with open(readme_path, "r", encoding="utf-8") as f:
|
||||||
|
readme_content = f.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||||
|
|
||||||
|
plugin_info = None
|
||||||
|
if plugin:
|
||||||
|
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||||
|
|
||||||
|
return plugin_info
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user