Compare commits
894 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3b0ca8ef6 | ||
|
|
9e266eb6d5 | ||
|
|
7231403e16 | ||
|
|
344a486fd7 | ||
|
|
4fd831875d | ||
|
|
0988d067ea | ||
|
|
3b6dd7e15a | ||
|
|
757d2a3947 | ||
|
|
61b71143f2 | ||
|
|
1b343a36c9 | ||
|
|
8e94937060 | ||
|
|
a4f212a18f | ||
|
|
caafb73190 | ||
|
|
09482799c9 | ||
|
|
37f93d1760 | ||
|
|
725f2e5204 | ||
|
|
967198fae0 | ||
|
|
43d57f6dcb | ||
|
|
6afa4db577 | ||
|
|
3b8c3fb29a | ||
|
|
921c3b0627 | ||
|
|
c0fadb45ab | ||
|
|
a1481fb179 | ||
|
|
987cd972d3 | ||
|
|
bdf25976a3 | ||
|
|
87c3aff4ce | ||
|
|
99350a957a | ||
|
|
319068dc7e | ||
|
|
cd18806c39 | ||
|
|
95b08b2023 | ||
|
|
0e70f76c86 | ||
|
|
4d414a2994 | ||
|
|
3d22772d4e | ||
|
|
0b381e2570 | ||
|
|
f2cc4311c5 | ||
|
|
e349671fdf | ||
|
|
01c02d5efa | ||
|
|
b62b1f3870 | ||
|
|
8844830859 | ||
|
|
0c51ee4b64 | ||
|
|
11920d5e31 | ||
|
|
848ea1eb63 | ||
|
|
a216519486 | ||
|
|
b04606c38e | ||
|
|
38072beea7 | ||
|
|
b843f1fa03 | ||
|
|
560d40e571 | ||
|
|
5f0b8161b7 | ||
|
|
062d482917 | ||
|
|
7cd1eeac30 | ||
|
|
bafa473c8e | ||
|
|
750cf46b2e | ||
|
|
68885a4bbc | ||
|
|
bcc99a8904 | ||
|
|
59fbd98db3 | ||
|
|
b70ed425f1 | ||
|
|
45ef5811c8 | ||
|
|
3b137ac762 | ||
|
|
1ddb0caf73 | ||
|
|
ae4c6fe2dd | ||
|
|
db257af58e | ||
|
|
735368c71b | ||
|
|
9e04e3679b | ||
|
|
43b8414727 | ||
|
|
5a00187147 | ||
|
|
cb525c7c84 | ||
|
|
d88420dd03 | ||
|
|
b9a983f8e0 | ||
|
|
42431ea7db | ||
|
|
f9459e4abb | ||
|
|
72f917d611 | ||
|
|
9fd1d19e93 | ||
|
|
062af1ac08 | ||
|
|
41bd76e091 | ||
|
|
cfd3f4b199 | ||
|
|
79d38f9597 | ||
|
|
b3866559e1 | ||
|
|
4d186baa35 | ||
|
|
8ed3d5f3db | ||
|
|
f0c8f39b6d | ||
|
|
431db8fc9b | ||
|
|
ba252c5356 | ||
|
|
a2812c39c0 | ||
|
|
0490758820 | ||
|
|
9b36a5c8a6 | ||
|
|
c1cf2be533 | ||
|
|
109650faf3 | ||
|
|
e54eaab842 | ||
|
|
43b6297b5d | ||
|
|
c20f4f5adf | ||
|
|
dc1f222cd2 | ||
|
|
c2b687212c | ||
|
|
849913276d | ||
|
|
23579c1e4a | ||
|
|
e031161fd4 | ||
|
|
4800ee6c0a | ||
|
|
d3a7fef9b0 | ||
|
|
40822fe77a | ||
|
|
837b670213 | ||
|
|
57ce69f3fb | ||
|
|
be022c4894 | ||
|
|
8a366964bb | ||
|
|
ee86b68470 | ||
|
|
60352307aa | ||
|
|
3ebd2f746f | ||
|
|
1c1a65b637 | ||
|
|
010e60d029 | ||
|
|
7a25568861 | ||
|
|
5f4f913661 | ||
|
|
ccd0e34a53 | ||
|
|
72f1ffccd3 | ||
|
|
ea7a52945f | ||
|
|
89d4d1351a | ||
|
|
b757c91d93 | ||
|
|
27203d7a4d | ||
|
|
9ad4e18ac5 | ||
|
|
fcdc8f3ce7 | ||
|
|
78b994b84a | ||
|
|
58bfc677e2 | ||
|
|
7d17285a0c | ||
|
|
e9eb00a0d4 | ||
|
|
48d07af574 | ||
|
|
2fc62efd88 | ||
|
|
be516d75bd | ||
|
|
951d5fde85 | ||
|
|
1389abc052 | ||
|
|
19ad67a77f | ||
|
|
641f308344 | ||
|
|
9f097fa4d5 | ||
|
|
5ad362c52b | ||
|
|
614f238a61 | ||
|
|
dec91950bc | ||
|
|
6cef9c23f0 | ||
|
|
3f568bf136 | ||
|
|
5484b421ce | ||
|
|
02f21e07d3 | ||
|
|
fff1f23a83 | ||
|
|
a056ec0d38 | ||
|
|
2eb9e5dde3 | ||
|
|
627d2a4701 | ||
|
|
76895fe86d | ||
|
|
64c3c85780 | ||
|
|
7288348857 | ||
|
|
62e73299b1 | ||
|
|
fe76c41ed8 | ||
|
|
1a92edf8be | ||
|
|
b63b606a4e | ||
|
|
8e2ef3d22b | ||
|
|
c6c4a32283 | ||
|
|
b70b3b158e | ||
|
|
3d59ab8108 | ||
|
|
b6c3089510 | ||
|
|
bd92aac280 | ||
|
|
5299e802e9 | ||
|
|
8e5a57d7dd | ||
|
|
beaa324fb6 | ||
|
|
79e64fe206 | ||
|
|
93f525e3fe | ||
|
|
aacb803c64 | ||
|
|
8a0665b222 | ||
|
|
20e41a7f73 | ||
|
|
93a1699a35 | ||
|
|
c33c07e4af | ||
|
|
c7484d0cc9 | ||
|
|
fb85a7bb35 | ||
|
|
42ff9a4d34 | ||
|
|
005e9eae7c | ||
|
|
3e325debcc | ||
|
|
a221de9a2b | ||
|
|
32b0cc1865 | ||
|
|
bbf85f8a12 | ||
|
|
67a0172b28 | ||
|
|
fb19d4d45b | ||
|
|
a156b1af14 | ||
|
|
a604b4943c | ||
|
|
3f0b6435d9 | ||
|
|
e0f029e2cb | ||
|
|
89d3fd5fab | ||
|
|
a38b00be6b | ||
|
|
0e8d52b591 | ||
|
|
298c77740d | ||
|
|
c681aae8ee | ||
|
|
faef98b089 | ||
|
|
84a3e0a30b | ||
|
|
69bd553ce0 | ||
|
|
fd0c0f8975 | ||
|
|
860ceb06b4 | ||
|
|
ecf501bf72 | ||
|
|
81a2ed1e25 | ||
|
|
76ab28338a | ||
|
|
9a56c9630f | ||
|
|
53b9497c18 | ||
|
|
750b16b6ee | ||
|
|
0ee3e0779a | ||
|
|
333c2d9299 | ||
|
|
ad37ff5048 | ||
|
|
33f86f3bde | ||
|
|
8acb969a49 | ||
|
|
b74b5933b8 | ||
|
|
681c556b7e | ||
|
|
1746684e52 | ||
|
|
0b93d06555 | ||
|
|
8a8b8c7c27 | ||
|
|
6b6577006d | ||
|
|
23ee5e81c9 | ||
|
|
483f55e4b1 | ||
|
|
1bb1bc2553 | ||
|
|
a4e4e36f94 | ||
|
|
6849415812 | ||
|
|
86f6cb038e | ||
|
|
7480a1d6ce | ||
|
|
3cd10117dd | ||
|
|
0caf19d390 | ||
|
|
5c14ebb049 | ||
|
|
9717a736b1 | ||
|
|
9c9ab50d1a | ||
|
|
d4bcb8174e | ||
|
|
9e7fe773bd | ||
|
|
aca18fab0f | ||
|
|
691de01b79 | ||
|
|
3383f15142 | ||
|
|
84c1593889 | ||
|
|
3c80fa1e33 | ||
|
|
06b16a1deb | ||
|
|
4c4246fb09 | ||
|
|
364be1e9f6 | ||
|
|
f959ed71aa | ||
|
|
5c4326c302 | ||
|
|
125fc3a622 | ||
|
|
6b9e785db3 | ||
|
|
25d34e9a43 | ||
|
|
457d4aa1dc | ||
|
|
ff0c0992ff | ||
|
|
d379e012c4 | ||
|
|
151fff26fd | ||
|
|
3d0d561215 | ||
|
|
22d586ed7b | ||
|
|
6dc19b29e8 | ||
|
|
50975a87d4 | ||
|
|
ce721d9f0f | ||
|
|
20510a33f7 | ||
|
|
3abd9c8763 | ||
|
|
e9eff7420b | ||
|
|
64c250c9d8 | ||
|
|
8047f82bfd | ||
|
|
af6467fb3d | ||
|
|
3ff1664aec | ||
|
|
34ea2b44b8 | ||
|
|
6c8d851109 | ||
|
|
d678299a74 | ||
|
|
7aed0db2b6 | ||
|
|
0355524345 | ||
|
|
0a43e4672e | ||
|
|
71e0ccdfec | ||
|
|
1df33ac3c8 | ||
|
|
7334090ac1 | ||
|
|
6b0f044198 | ||
|
|
ddf54c9cf8 | ||
|
|
7c64e184e2 | ||
|
|
a904db033c | ||
|
|
b234856b02 | ||
|
|
89d51d2afc | ||
|
|
37cb9678e9 | ||
|
|
0500ff333a | ||
|
|
08528510ef | ||
|
|
ddbd03dc1e | ||
|
|
ade87f378a | ||
|
|
4db14b905f | ||
|
|
b669b31451 | ||
|
|
1cb2b62f81 | ||
|
|
e5828713cf | ||
|
|
d10cb84068 | ||
|
|
4222f8516f | ||
|
|
7f998c7611 | ||
|
|
db46000337 | ||
|
|
1aac8d8041 | ||
|
|
c59c8e05f7 | ||
|
|
4942d0a629 | ||
|
|
873b7715f4 | ||
|
|
98e7ed6920 | ||
|
|
046f5e645e | ||
|
|
f5e5a7094c | ||
|
|
154125fee6 | ||
|
|
9f8e960ebe | ||
|
|
4179b0be0a | ||
|
|
28bafa38db | ||
|
|
b07552565e | ||
|
|
c4427471d2 | ||
|
|
08f81c6784 | ||
|
|
a471e98aca | ||
|
|
75a8fcc8a0 | ||
|
|
46ef76c168 | ||
|
|
66637446c9 | ||
|
|
21efeb888a | ||
|
|
a4ee8b5322 | ||
|
|
36519ac47e | ||
|
|
3f514fceca | ||
|
|
c2249fdfac | ||
|
|
c610719a44 | ||
|
|
36a6c2461a | ||
|
|
c29f22c39e | ||
|
|
30d3062944 | ||
|
|
69ba75abf4 | ||
|
|
e4d486fec5 | ||
|
|
f242144dcf | ||
|
|
02dee2d664 | ||
|
|
a3dd2c3069 | ||
|
|
a23425e8aa | ||
|
|
be79ddc9a3 | ||
|
|
7d71015e8c | ||
|
|
ad54549b51 | ||
|
|
6cf032a164 | ||
|
|
6390d796ac | ||
|
|
98b8411905 | ||
|
|
ddf1029afa | ||
|
|
1effbc5cc9 | ||
|
|
414b645e9f | ||
|
|
398c76f496 | ||
|
|
1bc456dd95 | ||
|
|
2e8421884e | ||
|
|
70d9b193ac | ||
|
|
b49c11004a | ||
|
|
34843eea90 | ||
|
|
2d6d7f31e8 | ||
|
|
7a24cbff1c | ||
|
|
1e7eb2cf1c | ||
|
|
361256e016 | ||
|
|
8838dbd003 | ||
|
|
13a95e1f2b | ||
|
|
1aaa451a3e | ||
|
|
cbba81e54d | ||
|
|
370868dfac | ||
|
|
77f692aae2 | ||
|
|
9318e205ea | ||
|
|
ebcc717c19 | ||
|
|
4c16b564ee | ||
|
|
e2283d1453 | ||
|
|
d891801c5a | ||
|
|
de75386944 | ||
|
|
82dc37de50 | ||
|
|
b6fa7f62dc | ||
|
|
f9e0a95c5e | ||
|
|
b2c6e12647 | ||
|
|
caffb83780 | ||
|
|
8882cb5479 | ||
|
|
75dace2dee | ||
|
|
ad6487d042 | ||
|
|
a91604e8ab | ||
|
|
c364f7c643 | ||
|
|
53435ba184 | ||
|
|
25f8d5519b | ||
|
|
2e4fef6c66 | ||
|
|
80b2b7dc00 | ||
|
|
8585cd8e21 | ||
|
|
9fa2a7eeea | ||
|
|
2d1f74228d | ||
|
|
3d6f7aa0e1 | ||
|
|
3dea60366a | ||
|
|
d4d9a1df4c | ||
|
|
7d6975fd31 | ||
|
|
08be52ed17 | ||
|
|
682a7700c2 | ||
|
|
9d87009216 | ||
|
|
ef86838f62 | ||
|
|
35468233f8 | ||
|
|
26e229867d | ||
|
|
3a1578b3c6 | ||
|
|
d5e3d2cbbc | ||
|
|
c095248176 | ||
|
|
44601c8954 | ||
|
|
135dbb8f07 | ||
|
|
c95682a0c7 | ||
|
|
d177b9f7fa | ||
|
|
9b57615d94 | ||
|
|
c03f3eacd1 | ||
|
|
a26e395932 | ||
|
|
0870b87c96 | ||
|
|
b52a44a7dd | ||
|
|
0a290aafef | ||
|
|
9014d4c410 | ||
|
|
60e58b4f5f | ||
|
|
620e74a6aa | ||
|
|
efa287ed35 | ||
|
|
a24eb9d9b0 | ||
|
|
bd3dab8aae | ||
|
|
4fe1ebaa5b | ||
|
|
c5e944744b | ||
|
|
0c396181f7 | ||
|
|
0034474219 | ||
|
|
8136ad8287 | ||
|
|
681940d466 | ||
|
|
16488506e8 | ||
|
|
122fccc041 | ||
|
|
9d0ad35403 | ||
|
|
f9ec97e026 | ||
|
|
95495a2647 | ||
|
|
e3310a605c | ||
|
|
b55719bf28 | ||
|
|
b957b51279 | ||
|
|
90bcfab369 | ||
|
|
f8a8e30641 | ||
|
|
25cb98e7a7 | ||
|
|
03e1bb7cf9 | ||
|
|
85dbb24f3a | ||
|
|
d817635782 | ||
|
|
2f4f237810 | ||
|
|
5ac94d810f | ||
|
|
39dc46dc25 | ||
|
|
0d9cf725f7 | ||
|
|
e55dbead5b | ||
|
|
7d046e5b30 | ||
|
|
8b4693cf66 | ||
|
|
a1172c9a82 | ||
|
|
1ed2bd33f0 | ||
|
|
4c159bd0ba | ||
|
|
050654b2a9 | ||
|
|
61b261e1b2 | ||
|
|
017b010206 | ||
|
|
00f5189f58 | ||
|
|
4a8309ed1f | ||
|
|
76cfc31a1d | ||
|
|
d9ec434699 | ||
|
|
239f3c40be | ||
|
|
09c8c6e670 | ||
|
|
7e4ad01c94 | ||
|
|
ed98e269ef | ||
|
|
b47d63334f | ||
|
|
5e2a3a5aea | ||
|
|
1a7eb21fc7 | ||
|
|
834a51cdc9 | ||
|
|
1b69d99c06 | ||
|
|
ad189933c6 | ||
|
|
9d86ff32de | ||
|
|
278bb57a58 | ||
|
|
0ba494e0ba | ||
|
|
8b247054bb | ||
|
|
7c5c8e4e0d | ||
|
|
ad106a27f3 | ||
|
|
9d6f61b49e | ||
|
|
02368954a0 | ||
|
|
b477a35a01 | ||
|
|
16622887de | ||
|
|
9059d1fb17 | ||
|
|
df2b008d82 | ||
|
|
0da871efd0 | ||
|
|
1c55349f81 | ||
|
|
9309fa1e81 | ||
|
|
5996189f91 | ||
|
|
bd2b984bfb | ||
|
|
194409a117 | ||
|
|
27978b216d | ||
|
|
c38fa77ce6 | ||
|
|
3eb49f7422 | ||
|
|
1989d615d2 | ||
|
|
239412d265 | ||
|
|
375a419a9e | ||
|
|
875c8ab424 | ||
|
|
c9bfc810ce | ||
|
|
46ecb16949 | ||
|
|
f6dc16f17b | ||
|
|
4eef42f730 | ||
|
|
8612d9a771 | ||
|
|
0caff054f5 | ||
|
|
4aa91ad599 | ||
|
|
7a0864f5c2 | ||
|
|
73dc0dfcf6 | ||
|
|
1ff9a69339 | ||
|
|
179eb5d847 | ||
|
|
52c868828c | ||
|
|
7eea4615b6 | ||
|
|
d9b351df1a | ||
|
|
d6a785b645 | ||
|
|
79db828a01 | ||
|
|
a5ffb0f8dc | ||
|
|
9492fcde74 | ||
|
|
d2456ce4cd | ||
|
|
7de27abc8d | ||
|
|
d8155bc8eb | ||
|
|
cf08e52a92 | ||
|
|
768398b991 | ||
|
|
24c20a19f1 | ||
|
|
8fbcbcd4c0 | ||
|
|
e0da5bb943 | ||
|
|
36fbc4fb82 | ||
|
|
cb11051f42 | ||
|
|
a824781d14 | ||
|
|
600a2c6748 | ||
|
|
77df64bfb5 | ||
|
|
2d6e54903c | ||
|
|
baa2b83df9 | ||
|
|
1ff02446af | ||
|
|
b58c6ba762 | ||
|
|
611a902000 | ||
|
|
c1b3f9dd29 | ||
|
|
7c5a88a6a6 | ||
|
|
be9abfef58 | ||
|
|
b549c9377e | ||
|
|
a5b00dbf74 | ||
|
|
90e2e14cd7 | ||
|
|
14bb245424 | ||
|
|
b63a0f3a45 | ||
|
|
e1f8842d7f | ||
|
|
3dda5fb268 | ||
|
|
248e0c5240 | ||
|
|
0297a43de6 | ||
|
|
2b4f66e0cf | ||
|
|
e622af2cc3 | ||
|
|
f527b1b5a6 | ||
|
|
c15b13a107 | ||
|
|
bc06acdd25 | ||
|
|
5252870733 | ||
|
|
3cac6a47a5 | ||
|
|
49bba9bf98 | ||
|
|
f4d12e4e5e | ||
|
|
d305211a36 | ||
|
|
9ec44d6f97 | ||
|
|
175bb3ee01 | ||
|
|
036c78750f | ||
|
|
a18de9de7d | ||
|
|
59fbbd5987 | ||
|
|
7e89fbc907 | ||
|
|
0956f240b3 | ||
|
|
f9db97c6b0 | ||
|
|
a2443c4ac1 | ||
|
|
095bd95044 | ||
|
|
b569209647 | ||
|
|
9057cac2b9 | ||
|
|
f9a6c685df | ||
|
|
208eb4f454 | ||
|
|
b3cb9e6714 | ||
|
|
5f9233f9b7 | ||
|
|
16447ae597 | ||
|
|
103edd5260 | ||
|
|
928089bf0f | ||
|
|
e5bd74695a | ||
|
|
f796969465 | ||
|
|
10756175b7 | ||
|
|
5637a71486 | ||
|
|
bcebd0fb62 | ||
|
|
3817d3ca87 | ||
|
|
4dd714e814 | ||
|
|
61e8bb49ec | ||
|
|
103dcd3761 | ||
|
|
54ac135fc8 | ||
|
|
86582809fc | ||
|
|
974d648f19 | ||
|
|
a79afc9597 | ||
|
|
e4883241d9 | ||
|
|
babf223745 | ||
|
|
c7d91730b6 | ||
|
|
71246b65c9 | ||
|
|
50076b647e | ||
|
|
a1a788dce8 | ||
|
|
a611b4f346 | ||
|
|
7f6ed674b4 | ||
|
|
aa3cfd887a | ||
|
|
2649d46d8d | ||
|
|
e23ffe6f02 | ||
|
|
96f3c3729a | ||
|
|
11e9d47ce2 | ||
|
|
efbc8e4383 | ||
|
|
bc7404409f | ||
|
|
8677d70baf | ||
|
|
f39253f0e1 | ||
|
|
68c1957267 | ||
|
|
a275aa2e4d | ||
|
|
cadbac9948 | ||
|
|
82673e8ddd | ||
|
|
bee51024b3 | ||
|
|
3437cb73ec | ||
|
|
d01d1a8520 | ||
|
|
5aa842cf66 | ||
|
|
03282dee0f | ||
|
|
98e8ecb8e2 | ||
|
|
9451dc3fd4 | ||
|
|
e1d3759f55 | ||
|
|
0ec382c86b | ||
|
|
756087c9f1 | ||
|
|
3e7c47e873 | ||
|
|
e3ffdbc308 | ||
|
|
645cace4d6 | ||
|
|
0959d5986b | ||
|
|
89605c29a7 | ||
|
|
e527f31213 | ||
|
|
a0dbd99928 | ||
|
|
17d39c7a4a | ||
|
|
54edaebbd9 | ||
|
|
d587a6f64c | ||
|
|
2371c32be5 | ||
|
|
c9abb8352c | ||
|
|
8995e62e73 | ||
|
|
316147a8db | ||
|
|
1fdcfc7a30 | ||
|
|
8e2c633cd4 | ||
|
|
786b0e4a54 | ||
|
|
c38c1c3c35 | ||
|
|
7d856756f4 | ||
|
|
f0d1d365e0 | ||
|
|
8e2d666ff8 | ||
|
|
38d7be1d5f | ||
|
|
431e2fad72 | ||
|
|
b3b63be8fc | ||
|
|
071fc7d6ef | ||
|
|
2a37f7edac | ||
|
|
c656ad5e2c | ||
|
|
da14a89490 | ||
|
|
cf22eae467 | ||
|
|
b199bddb0b | ||
|
|
2188ea82de | ||
|
|
1fa13d0177 | ||
|
|
ed508af424 | ||
|
|
5df26864d5 | ||
|
|
837111b17e | ||
|
|
a6b363b433 | ||
|
|
2807e1e892 | ||
|
|
0a2abd8214 | ||
|
|
8beb7acdb1 | ||
|
|
466c80b94d | ||
|
|
36c0cfc9a9 | ||
|
|
35ba1b3345 | ||
|
|
d00821d1c7 | ||
|
|
6c1b3f242b | ||
|
|
9f9da1e0c9 | ||
|
|
14fb4b70bd | ||
|
|
b1049540a4 | ||
|
|
5e2909df33 | ||
|
|
c122dad21f | ||
|
|
48ae686602 | ||
|
|
bf2c3a1a81 | ||
|
|
96e7a93886 | ||
|
|
dba1ed1e19 | ||
|
|
a24514876b | ||
|
|
466a1c1c41 | ||
|
|
a2d5e9f40f | ||
|
|
1bbff1d161 | ||
|
|
0948bae99b | ||
|
|
850db41596 | ||
|
|
7bafc87e2b | ||
|
|
1a0de02a15 | ||
|
|
6d5d278624 | ||
|
|
3b4cc48fa0 | ||
|
|
c908461088 | ||
|
|
53d1398d30 | ||
|
|
782c0367d0 | ||
|
|
4678222e9b | ||
|
|
f71dc3e4be | ||
|
|
f6233893bd | ||
|
|
6427bcf130 | ||
|
|
8fa41b706c | ||
|
|
4706c4438d | ||
|
|
0c8ebc2b06 | ||
|
|
b3b5ebc2ca | ||
|
|
b8aa23ccc5 | ||
|
|
364843db29 | ||
|
|
aa56c8f7e6 | ||
|
|
8e9fd27058 | ||
|
|
b75908cb2a | ||
|
|
af6df49ce1 | ||
|
|
bd3bdb5769 | ||
|
|
98fe193b21 | ||
|
|
26cbc9e8b1 | ||
|
|
ebb8c43fd0 | ||
|
|
8c7344f1c4 | ||
|
|
5c32a17787 | ||
|
|
aff520e69a | ||
|
|
45e627c33c | ||
|
|
7a1b158f83 | ||
|
|
6374c5d49d | ||
|
|
fd460b19d4 | ||
|
|
dff7cc4ca5 | ||
|
|
d013320bec | ||
|
|
fc6dcfaf21 | ||
|
|
a001270bd2 | ||
|
|
9e67883fbd | ||
|
|
f1a448708c | ||
|
|
a4bfa96502 | ||
|
|
595b83a256 | ||
|
|
8d34f77321 | ||
|
|
67095f97b1 | ||
|
|
50740c94ab | ||
|
|
4db4cfeda2 | ||
|
|
ad13cef89c | ||
|
|
855fc6fcd1 | ||
|
|
8f12244e51 | ||
|
|
fe0213465c | ||
|
|
f984047004 | ||
|
|
19e9e2d090 | ||
|
|
7fe3b97d00 | ||
|
|
9cd243da47 | ||
|
|
e43208c2e9 | ||
|
|
dc016fc22f | ||
|
|
c6f037cae2 | ||
|
|
f049830e28 | ||
|
|
dd1995ae0b | ||
|
|
23dc233569 | ||
|
|
0977aa7d0d | ||
|
|
24862b0672 | ||
|
|
f05a57efc3 | ||
|
|
65331a9d7c | ||
|
|
f7ae287e40 | ||
|
|
45f380b1f6 | ||
|
|
9e6b329df4 | ||
|
|
43cd34d94c | ||
|
|
9fa00aff9a | ||
|
|
9a56dcb1be | ||
|
|
fdfe7bbe59 | ||
|
|
3a99a60792 | ||
|
|
fa2b4e14df | ||
|
|
35322a6900 | ||
|
|
2ccf29d61e | ||
|
|
b068013343 | ||
|
|
d839e72998 | ||
|
|
d7c9a8ed29 | ||
|
|
6837d4d692 | ||
|
|
8aba83735b | ||
|
|
aa51187747 | ||
|
|
5f07a9ae95 | ||
|
|
a2ca767bf4 | ||
|
|
5806c74e7c | ||
|
|
0481e1d45e | ||
|
|
3177b61421 | ||
|
|
6009cf5dfa | ||
|
|
0a970e8c31 | ||
|
|
aa276ca6af | ||
|
|
9f02dd13ff | ||
|
|
609e723322 | ||
|
|
c564a1d53e | ||
|
|
a7fe31f28b | ||
|
|
a84dc599d6 | ||
|
|
8da029add9 | ||
|
|
ba45a2d270 | ||
|
|
cb56b22aea | ||
|
|
23cc5b31ba | ||
|
|
e8d99f0460 | ||
|
|
6bcd10cd5c | ||
|
|
619fb20c5f | ||
|
|
386a312e96 | ||
|
|
2759d347e6 | ||
|
|
b6ec327b49 | ||
|
|
ee02d622ba | ||
|
|
5c4a6083f5 | ||
|
|
49e63a3d3d | ||
|
|
6bae9dc9ed | ||
|
|
5fa1979a46 | ||
|
|
b40d4fa315 | ||
|
|
4d2ff7cd5b | ||
|
|
d8ec0e64d0 | ||
|
|
82e979cc07 | ||
|
|
8c132a51f5 | ||
|
|
40bd372cc1 | ||
|
|
212e114270 | ||
|
|
b0e9de6951 | ||
|
|
3489522bbb | ||
|
|
96237abc03 | ||
|
|
7155b4f0ac | ||
|
|
a8b2b09e0f | ||
|
|
6858b8c555 | ||
|
|
0e493b1a0e | ||
|
|
37d478f970 | ||
|
|
7d0d42a49f | ||
|
|
0eb1684ef1 | ||
|
|
9b0b723143 | ||
|
|
532bc6e1e6 | ||
|
|
fe3ed4c454 | ||
|
|
b5ec89e586 | ||
|
|
895e7397c2 | ||
|
|
59b767957a | ||
|
|
17d4bf8f22 | ||
|
|
836be3b097 | ||
|
|
310415bea9 | ||
|
|
aafc1276a9 | ||
|
|
2993e794cc | ||
|
|
58cb9cfb2d | ||
|
|
fbdf0901d5 | ||
|
|
af8c81b621 | ||
|
|
06b5275e48 | ||
|
|
ad95572d5f | ||
|
|
0021cfc4bc | ||
|
|
aebc7850f4 | ||
|
|
1b7efbc607 | ||
|
|
3800e96d14 | ||
|
|
461f1bb07c | ||
|
|
7d4c07e4f6 | ||
|
|
31b788f463 | ||
|
|
96ab761f73 | ||
|
|
2b3f05c039 | ||
|
|
f2e8303b66 | ||
|
|
2a614b545b | ||
|
|
5c0ab21f68 | ||
|
|
689d109438 | ||
|
|
2a6934b283 | ||
|
|
760cb94e9a | ||
|
|
2a6cff0013 | ||
|
|
ce578f0417 | ||
|
|
1745bdb9e2 | ||
|
|
3f90b89c3c | ||
|
|
f343e40d15 | ||
|
|
5cc4be9e65 | ||
|
|
da5aada002 | ||
|
|
07f2ee9ad9 | ||
|
|
12f4e1146f | ||
|
|
92c57e5476 | ||
|
|
a923baacd8 | ||
|
|
999b094d55 | ||
|
|
d4213f2352 | ||
|
|
3f65c9a066 | ||
|
|
1d427e2645 | ||
|
|
36414c4b00 | ||
|
|
47e253d76c | ||
|
|
b73cf84df0 | ||
|
|
a5b885a774 | ||
|
|
0c785413da | ||
|
|
482d7ef5f7 | ||
|
|
9f9073c0ff | ||
|
|
ef05ff4abd | ||
|
|
5848aae435 | ||
|
|
fb06f33de0 | ||
|
|
0d7ddb149e | ||
|
|
4f2d7b9c4e | ||
|
|
c02ed96f6f | ||
|
|
3b2ac891b2 | ||
|
|
ef0108881b | ||
|
|
af48975a6b | ||
|
|
6441b149ab | ||
|
|
f8892881f8 | ||
|
|
228aec5401 | ||
|
|
68ad48ff55 | ||
|
|
541ba64032 | ||
|
|
2d870b798c | ||
|
|
0f1fe1ab63 | ||
|
|
73cc86ddb1 | ||
|
|
23128f4be2 | ||
|
|
92200d0e82 | ||
|
|
d6e8655792 | ||
|
|
37076d7920 | ||
|
|
78347ec91b | ||
|
|
9ded102a0a | ||
|
|
59b7d8b8cb | ||
|
|
f5b97f6762 | ||
|
|
d47da241af | ||
|
|
4611ce15eb | ||
|
|
aa8c56a688 | ||
|
|
ef44d4471a | ||
|
|
5581eae957 | ||
|
|
ec46dfaac9 | ||
|
|
6042a047bd | ||
|
|
6ca9e2a753 | ||
|
|
618eabfe5c | ||
|
|
bb5db2e9d0 | ||
|
|
97e4d169b3 | ||
|
|
50e44b1473 | ||
|
|
38588dd3fa | ||
|
|
d183388347 | ||
|
|
1e69d59384 | ||
|
|
00f008f94d | ||
|
|
3c28001a74 | ||
|
|
76a6218be6 | ||
|
|
6c1de1bbd6 | ||
|
|
d7678081da | ||
|
|
5e4ba563cb | ||
|
|
8afbe77b0a | ||
|
|
2ef139b59a | ||
|
|
1f0d2d9b89 | ||
|
|
37a1f144ab | ||
|
|
9a7a654596 | ||
|
|
9abccd63cf | ||
|
|
93fea77182 | ||
|
|
19797243f6 | ||
|
|
c9c733d925 | ||
|
|
a7d7678c78 | ||
|
|
c0911921c7 | ||
|
|
4a4241d57a | ||
|
|
c9426bb6eb | ||
|
|
db4abd169a | ||
|
|
80b6958599 | ||
|
|
80058c781a | ||
|
|
44bd2e36f3 | ||
|
|
3589a5e5be | ||
|
|
13ef033f0e | ||
|
|
3f8c68bbca | ||
|
|
4275cea82b | ||
|
|
a0bcb5339a | ||
|
|
43deec4a4b | ||
|
|
2bc433a30b | ||
|
|
eb2b395932 | ||
|
|
2bfd1c0bf2 | ||
|
|
7228c4b13f | ||
|
|
9351d7471f | ||
|
|
1cf49998bc | ||
|
|
6ae86597e8 | ||
|
|
c578ff25bd | ||
|
|
2934a3e3be | ||
|
|
ceaa69da75 | ||
|
|
fa8e731576 |
@@ -1,3 +0,0 @@
|
|||||||
comment:
|
|
||||||
layout: "condensed_header, condensed_files, condensed_footer"
|
|
||||||
hide_project_coverage: TRUE
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
[run]
|
|
||||||
omit =
|
|
||||||
*/site-packages/*
|
|
||||||
*/dist-packages/*
|
|
||||||
your_package_name/tests/*
|
|
||||||
40
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
Normal file
40
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
name: '🥳 发布插件'
|
||||||
|
title: "[Plugin] 插件名"
|
||||||
|
description: 提交插件到插件市场
|
||||||
|
labels: [ "plugin-publish" ]
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: 插件仓库
|
||||||
|
description: 插件的 GitHub 仓库链接
|
||||||
|
placeholder: >
|
||||||
|
如 https://github.com/Soulter/astrbot-github-cards
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: 描述
|
||||||
|
value: |
|
||||||
|
插件名:
|
||||||
|
插件作者:
|
||||||
|
插件简介:
|
||||||
|
支持的消息平台:(必填,如 QQ、微信、飞书)
|
||||||
|
标签:(可选)
|
||||||
|
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
|
||||||
|
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
attributes:
|
||||||
|
label: Code of Conduct
|
||||||
|
options:
|
||||||
|
- label: >
|
||||||
|
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: "❤️"
|
||||||
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -28,7 +28,7 @@ body:
|
|||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: AstrBot 版本与部署方式
|
label: AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
|
||||||
description: >
|
description: >
|
||||||
请提供您的 AstrBot 版本和部署方式。
|
请提供您的 AstrBot 版本和部署方式。
|
||||||
placeholder: >
|
placeholder: >
|
||||||
@@ -53,9 +53,9 @@ body:
|
|||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: 额外信息
|
label: 报错日志
|
||||||
description: >
|
description: >
|
||||||
任何额外信息,如报错日志、截图等。
|
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
|
||||||
placeholder: >
|
placeholder: >
|
||||||
请提供完整的报错日志或截图。
|
请提供完整的报错日志或截图。
|
||||||
validations:
|
validations:
|
||||||
@@ -65,7 +65,7 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
label: 你愿意提交 PR 吗?
|
label: 你愿意提交 PR 吗?
|
||||||
description: >
|
description: >
|
||||||
这绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
||||||
options:
|
options:
|
||||||
- label: 是的,我愿意提交 PR!
|
- label: 是的,我愿意提交 PR!
|
||||||
|
|
||||||
|
|||||||
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -8,3 +8,7 @@
|
|||||||
### Modifications
|
### Modifications
|
||||||
|
|
||||||
<!--简单解释你的改动-->
|
<!--简单解释你的改动-->
|
||||||
|
|
||||||
|
### Check
|
||||||
|
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||||
|
- [ ] 我新增/修复/优化的功能经过良好的测试
|
||||||
|
|||||||
31
.github/workflows/dashboard_ci.yml
vendored
Normal file
31
.github/workflows/dashboard_ci.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
name: AstrBot Dashboard CI
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: npm install, build
|
||||||
|
run: |
|
||||||
|
cd dashboard
|
||||||
|
npm install
|
||||||
|
npm run build
|
||||||
|
|
||||||
|
- name: Inject Commit SHA
|
||||||
|
id: get_sha
|
||||||
|
run: |
|
||||||
|
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||||
|
mkdir -p dashboard/dist/assets
|
||||||
|
echo $COMMIT_SHA > dashboard/dist/assets/version
|
||||||
|
|
||||||
|
- name: Archive production artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: dist-without-markdown
|
||||||
|
path: |
|
||||||
|
dashboard/dist
|
||||||
|
!dist/**/*.md
|
||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,6 +1,8 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
botpy.log
|
botpy.log
|
||||||
.vscode
|
.vscode
|
||||||
|
.venv*
|
||||||
|
.idea
|
||||||
data_v2.db
|
data_v2.db
|
||||||
data_v3.db
|
data_v3.db
|
||||||
configs/session
|
configs/session
|
||||||
@@ -17,8 +19,14 @@ addons/plugins
|
|||||||
|
|
||||||
tests/astrbot_plugin_openai
|
tests/astrbot_plugin_openai
|
||||||
chroma
|
chroma
|
||||||
node_modules/
|
dashboard/node_modules/
|
||||||
|
dashboard/dist/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
package-lock.json
|
package-lock.json
|
||||||
package.json
|
package.json
|
||||||
venv/*
|
venv/*
|
||||||
|
packages/python_interpreter/workplace
|
||||||
|
.venv/*
|
||||||
|
.conda/
|
||||||
|
.idea
|
||||||
|
pytest.ini
|
||||||
|
|||||||
13
.pre-commit-config.yaml
Normal file
13
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
default_install_hook_types: [pre-commit, prepare-commit-msg]
|
||||||
|
ci:
|
||||||
|
autofix_commit_msg: ":balloon: auto fixes by pre-commit hooks"
|
||||||
|
autofix_prs: true
|
||||||
|
autoupdate_branch: master
|
||||||
|
autoupdate_schedule: weekly
|
||||||
|
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.11.2
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
- id: ruff-format
|
||||||
12
Dockerfile
12
Dockerfile
@@ -9,10 +9,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
python3-dev \
|
python3-dev \
|
||||||
libffi-dev \
|
libffi-dev \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
|
ca-certificates \
|
||||||
|
bash \
|
||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN python -m pip install -r requirements.txt
|
RUN python -m pip install uv
|
||||||
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
|
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||||
|
|
||||||
|
# 释出 ffmpeg
|
||||||
|
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||||
|
|
||||||
|
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||||
|
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||||
|
|
||||||
EXPOSE 6185
|
EXPOSE 6185
|
||||||
EXPOSE 6186
|
EXPOSE 6186
|
||||||
|
|||||||
35
Dockerfile_with_node
Normal file
35
Dockerfile_with_node
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
WORKDIR /AstrBot
|
||||||
|
|
||||||
|
COPY . /AstrBot/
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
gcc \
|
||||||
|
build-essential \
|
||||||
|
python3-dev \
|
||||||
|
libffi-dev \
|
||||||
|
libssl-dev \
|
||||||
|
curl \
|
||||||
|
unzip \
|
||||||
|
ca-certificates \
|
||||||
|
bash \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Installation of Node.js
|
||||||
|
ENV NVM_DIR="/root/.nvm"
|
||||||
|
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
|
||||||
|
. "$NVM_DIR/nvm.sh" && \
|
||||||
|
nvm install 22 && \
|
||||||
|
nvm use 22
|
||||||
|
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
|
||||||
|
|
||||||
|
RUN python -m pip install uv
|
||||||
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
|
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
|
||||||
|
|
||||||
|
EXPOSE 6185
|
||||||
|
EXPOSE 6186
|
||||||
|
|
||||||
|
CMD ["python", "main.py"]
|
||||||
4
LICENSE
4
LICENSE
@@ -629,8 +629,8 @@ to attach them to the start of each source file to most effectively
|
|||||||
state the exclusion of warranty; and each file should have at least
|
state the exclusion of warranty; and each file should have at least
|
||||||
the "copyright" line and a pointer to where the full notice is found.
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
<one line to give the program's name and a brief idea of what it does.>
|
AstrBot is a llm-powered chatbot and develop framework.
|
||||||
Copyright (C) <year> <name of author>
|
Copyright (C) 2022-2099 Soulter
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
This program is free software: you can redistribute it and/or modify
|
||||||
it under the terms of the GNU Affero General Public License as published
|
it under the terms of the GNU Affero General Public License as published
|
||||||
|
|||||||
154
README.md
154
README.md
@@ -1,44 +1,57 @@
|
|||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/de10f24d-cd64-433a-90b8-16c0a60de24a" width=500>
|

|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
<h1>AstrBot</h1>
|
|
||||||
|
|
||||||
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||||
|
|
||||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
|
||||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
|
||||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
|
||||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
|
||||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
|
||||||
</a>
|
|
||||||
|
|
||||||
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
|
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||||
|
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||||
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
<a href="https://astrbot.app/">查看文档</a> |
|
<a href="https://astrbot.app/">查看文档</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||||
|
|
||||||
|
[](https://gitcode.com/Soulter/AstrBot)
|
||||||
|
|
||||||
|
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||||
|
-->
|
||||||
|
|
||||||
|
## ✨ 近期更新
|
||||||
|
|
||||||
|
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||||
|
|
||||||
## ✨ 主要功能
|
## ✨ 主要功能
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||||
|
|
||||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat、VChat)、Telegram。后续将支持钉钉、飞书、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||||
>
|
>
|
||||||
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM,因此无法在聊天页使用大模型。
|
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||||
|
|
||||||
## ✨ 使用方式
|
## ✨ 使用方式
|
||||||
|
|
||||||
@@ -48,43 +61,87 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
|
|
||||||
#### Windows 一键安装器部署
|
#### Windows 一键安装器部署
|
||||||
|
|
||||||
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||||
|
|
||||||
#### Replit 部署
|
#### 宝塔面板部署
|
||||||
|
|
||||||
[](https://repl.it/github/Soulter/AstrBot)
|
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||||
|
|
||||||
#### CasaOS 部署
|
#### CasaOS 部署
|
||||||
|
|
||||||
社区贡献的部署方式。
|
社区贡献的部署方式。
|
||||||
|
|
||||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
||||||
|
|
||||||
#### 手动部署
|
#### 手动部署
|
||||||
|
|
||||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
推荐使用 `uv`。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||||
|
pip install uv
|
||||||
|
uv run main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||||
|
|
||||||
|
#### Replit 部署
|
||||||
|
|
||||||
|
[](https://repl.it/github/Soulter/AstrBot)
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
|
|
||||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||||
| -------- | ------- | ------- | ------ |
|
| -------- | ------- | ------- | ------ |
|
||||||
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||||
| QQ 官方API | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||||
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
|
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
|
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||||
|
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
|
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||||
| 飞书 | 🚧 | 计划内 | - |
|
|
||||||
| Discord | 🚧 | 计划内 | - |
|
| Discord | 🚧 | 计划内 | - |
|
||||||
| WhatsApp | 🚧 | 计划内 | - |
|
| WhatsApp | 🚧 | 计划内 | - |
|
||||||
| 小爱音响 | 🚧 | 计划内 | - |
|
| 小爱音响 | 🚧 | 计划内 | - |
|
||||||
|
|
||||||
|
## ⚡ 提供商支持情况
|
||||||
|
|
||||||
|
| 名称 | 支持性 | 类型 | 备注 |
|
||||||
|
| -------- | ------- | ------- | ------- |
|
||||||
|
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||||
|
| Claude API | ✔ | 文本生成 | |
|
||||||
|
| Google Gemini API | ✔ | 文本生成 | |
|
||||||
|
| Dify | ✔ | LLMOps | |
|
||||||
|
| DashScope(阿里云百炼应用) | ✔ | LLMOps | |
|
||||||
|
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
|
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
|
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||||
|
| OneAPI | ✔ | LLM 分发系统 | |
|
||||||
|
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||||
|
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||||
|
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||||
|
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||||
|
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||||
|
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||||
|
|
||||||
## ❤️ 贡献
|
## ❤️ 贡献
|
||||||
|
|
||||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||||
|
|
||||||
对于新功能的添加,请先通过 Issue 讨论。
|
### 如何贡献
|
||||||
|
|
||||||
|
你可以通过查看问题或帮助审核 PR(拉取请求)来贡献。任何问题或 PR 都欢迎参与,以促进社区贡献。当然,这些只是建议,你可以以任何方式进行贡献。对于新功能的添加,请先通过 Issue 讨论。
|
||||||
|
|
||||||
|
### 开发环境
|
||||||
|
|
||||||
|
AstrBot 使用 `ruff` 进行代码格式化和检查。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/Soulter/AstrBot
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
## 🌟 支持
|
## 🌟 支持
|
||||||
|
|
||||||
@@ -94,55 +151,52 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
|
|
||||||
## ✨ Demo
|
## ✨ Demo
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
|
||||||
|
|
||||||
<div align='center'>
|
<div align='center'>
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||||
|
|
||||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||||
|
|
||||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
|
||||||
|
|
||||||
_✨ 自然语言待办事项 ✨_
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||||
|
|
||||||
_✨ 插件系统——部分插件展示 ✨_
|
_✨ 插件系统——部分插件展示 ✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
||||||
|
|
||||||
_✨ 管理面板 ✨_
|
_✨ WebUI ✨_
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
_✨ 内置 Web Chat,在线与机器人交互 ✨_
|
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
## ❤️ Special Thanks
|
||||||
|
|
||||||
|
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||||
|
|
||||||
|
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||||
|
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||||
|
</a>
|
||||||
|
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
[](https://star-history.com/#soulter/astrbot&Date)
|
[](https://star-history.com/#soulter/astrbot&Date)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- ## ✨ ATRI [Beta 测试]
|
## Disclaimer
|
||||||
|
|
||||||
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
1. The project is protected under the `AGPL-v3` opensource license.
|
||||||
|
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||||
|
3. Please ensure compliance with local laws and regulations when using this project.
|
||||||
|
|
||||||
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
|
_私は、高性能ですから!_
|
||||||
2. 长期记忆
|
|
||||||
3. 表情包理解与回复
|
|
||||||
4. TTS
|
|
||||||
-->
|
|
||||||
|
|
||||||
_アトリは、高性能ですから!_
|
|
||||||
|
|
||||||
|
|||||||
182
README_en.md
Normal file
182
README_en.md
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
<p align="center">
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
|
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||||
|
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||||
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
|
||||||
|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||
|

|
||||||
|
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||||
|
|
||||||
|
<a href="https://astrbot.app/">Documentation</a> |
|
||||||
|
<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
|
||||||
|
|
||||||
|
## ✨ Key Features
|
||||||
|
|
||||||
|
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
||||||
|
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
||||||
|
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
|
||||||
|
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
||||||
|
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
||||||
|
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||||
|
> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
|
||||||
|
|
||||||
|
## ✨ Deployment
|
||||||
|
|
||||||
|
#### Docker Deployment
|
||||||
|
|
||||||
|
See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
|
||||||
|
|
||||||
|
#### Windows Installer
|
||||||
|
|
||||||
|
Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
|
||||||
|
|
||||||
|
#### Replit Deployment
|
||||||
|
|
||||||
|
[](https://repl.it/github/Soulter/AstrBot)
|
||||||
|
|
||||||
|
#### CasaOS Deployment
|
||||||
|
|
||||||
|
Community-contributed method.
|
||||||
|
See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
|
||||||
|
|
||||||
|
#### Manual Deployment
|
||||||
|
|
||||||
|
See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
|
||||||
|
|
||||||
|
## ⚡ Platform Support
|
||||||
|
|
||||||
|
| Platform | Status | Details | Message Types |
|
||||||
|
| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
|
||||||
|
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
|
||||||
|
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
|
||||||
|
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
|
||||||
|
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
|
||||||
|
| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
|
||||||
|
| Feishu | ✔ | Group chats | Text, Images |
|
||||||
|
| WeChat Open Platform | 🚧 | Planned | - |
|
||||||
|
| Discord | 🚧 | Planned | - |
|
||||||
|
| WhatsApp | 🚧 | Planned | - |
|
||||||
|
| Xiaomi Speakers | 🚧 | Planned | - |
|
||||||
|
|
||||||
|
## Provider Support Status
|
||||||
|
|
||||||
|
| Name | Support | Type | Notes |
|
||||||
|
|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
|
||||||
|
| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
|
||||||
|
| Claude API | ✔ | Text Generation | |
|
||||||
|
| Google Gemini API | ✔ | Text Generation | |
|
||||||
|
| Dify | ✔ | LLMOps | |
|
||||||
|
| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
|
||||||
|
| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
|
||||||
|
| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
|
||||||
|
| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
|
||||||
|
| OneAPI | ✔ | LLM Distribution | |
|
||||||
|
| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
|
||||||
|
| SenseVoice | ✔ | Speech-to-Text | Local deployment |
|
||||||
|
| OpenAI TTS API | ✔ | Text-to-Speech | |
|
||||||
|
| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
|
||||||
|
|
||||||
|
# 🦌 Roadmap
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Suggestions welcome via Issues <3
|
||||||
|
|
||||||
|
- [ ] Ensure feature parity across all platform adapters
|
||||||
|
- [ ] Optimize plugin APIs
|
||||||
|
- [ ] Add default TTS services (e.g., GPT-Sovits)
|
||||||
|
- [ ] Enhance chat features with persistent memory
|
||||||
|
- [ ] i18n Planning
|
||||||
|
|
||||||
|
## ❤️ Contributions
|
||||||
|
|
||||||
|
All Issues/PRs welcome! Simply submit your changes to this project :)
|
||||||
|
|
||||||
|
For major features, please discuss via Issues first.
|
||||||
|
|
||||||
|
## 🌟 Support
|
||||||
|
|
||||||
|
- Star this project!
|
||||||
|
- Support via [Afdian](https://afdian.com/a/soulter)
|
||||||
|
- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
|
||||||
|
|
||||||
|
## ✨ Demos
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
|
||||||
|
|
||||||
|
<div align='center'>
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||||
|
|
||||||
|
_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||||
|
|
||||||
|
_✨ Multimodal Input, Web Search, Text-to-Image ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||||
|
|
||||||
|
_✨ Natural Language TODO Lists ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||||
|
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||||
|
|
||||||
|
_✨ Plugin System Showcase ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
||||||
|
|
||||||
|
_✨ Web Dashboard ✨_
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
_✨ Built-in Web Chat Interface ✨_
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## ⭐ Star History
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> If this project helps you, please give it a star <3
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
[](https://star-history.com/#soulter/astrbot&Date)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Disclaimer
|
||||||
|
|
||||||
|
1. Licensed under `AGPL-v3`.
|
||||||
|
2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
|
||||||
|
3. Users must comply with local laws and regulations.
|
||||||
|
|
||||||
|
<!-- ## ✨ ATRI [Beta]
|
||||||
|
|
||||||
|
Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||||
|
|
||||||
|
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
|
||||||
|
2. Long-term memory
|
||||||
|
3. Meme understanding & responses
|
||||||
|
4. TTS integration
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
170
README_ja.md
Normal file
170
README_ja.md
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
<p align="center">
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
|
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||||
|
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||||
|
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
|
||||||
|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||
|

|
||||||
|
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||||
|
|
||||||
|
<a href="https://astrbot.app/">ドキュメントを見る</a> |
|
||||||
|
<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
|
||||||
|
|
||||||
|
## ✨ 主な機能
|
||||||
|
|
||||||
|
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||||
|
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||||
|
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||||
|
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||||
|
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||||
|
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||||
|
>
|
||||||
|
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
|
||||||
|
|
||||||
|
## ✨ 使用方法
|
||||||
|
|
||||||
|
#### Docker デプロイ
|
||||||
|
|
||||||
|
公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください。
|
||||||
|
|
||||||
|
#### Windows ワンクリックインストーラーのデプロイ
|
||||||
|
|
||||||
|
コンピュータに Python(>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
|
||||||
|
|
||||||
|
#### Replit デプロイ
|
||||||
|
|
||||||
|
[](https://repl.it/github/Soulter/AstrBot)
|
||||||
|
|
||||||
|
#### CasaOS デプロイ
|
||||||
|
|
||||||
|
コミュニティが提供するデプロイ方法です。
|
||||||
|
|
||||||
|
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
|
||||||
|
|
||||||
|
#### 手動デプロイ
|
||||||
|
|
||||||
|
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
|
||||||
|
|
||||||
|
## ⚡ メッセージプラットフォームのサポート状況
|
||||||
|
|
||||||
|
| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
|
||||||
|
| -------- | ------- | ------- | ------ |
|
||||||
|
| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
|
||||||
|
| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
|
||||||
|
| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
|
||||||
|
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
|
||||||
|
| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
|
||||||
|
| Feishu | ✔ | グループチャット | テキスト、画像 |
|
||||||
|
| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
|
||||||
|
| Discord | 🚧 | 計画中 | - |
|
||||||
|
| WhatsApp | 🚧 | 計画中 | - |
|
||||||
|
| Xiaoai 音響 | 🚧 | 計画中 | - |
|
||||||
|
|
||||||
|
# 🦌 今後のロードマップ
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Issue でさらに多くの提案を歓迎します <3
|
||||||
|
|
||||||
|
- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
|
||||||
|
- [ ] プラグインインターフェースの最適化
|
||||||
|
- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
|
||||||
|
- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
|
||||||
|
- [ ] i18n の計画
|
||||||
|
|
||||||
|
## ❤️ 貢献
|
||||||
|
|
||||||
|
Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
|
||||||
|
|
||||||
|
新機能の追加については、まず Issue で議論してください。
|
||||||
|
|
||||||
|
## 🌟 サポート
|
||||||
|
|
||||||
|
- このプロジェクトに Star を付けてください!
|
||||||
|
- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
|
||||||
|
- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
|
||||||
|
|
||||||
|
## ✨ デモ
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
|
||||||
|
|
||||||
|
<div align='center'>
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||||
|
|
||||||
|
_✨ Docker ベースのサンドボックス化されたコードエグゼキューター(ベータテスト中)✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||||
|
|
||||||
|
_✨ 多モーダル、ウェブ検索、長文の画像変換(設定可能)✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||||
|
|
||||||
|
_✨ 自然言語タスク ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||||
|
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||||
|
|
||||||
|
_✨ プラグインシステム - 一部のプラグインの展示 ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width="600">
|
||||||
|
|
||||||
|
_✨ 管理パネル ✨_
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## ⭐ Star History
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
[](https://star-history.com/#soulter/astrbot&Date)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## スポンサー
|
||||||
|
|
||||||
|
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||||
|
|
||||||
|
## 免責事項
|
||||||
|
|
||||||
|
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||||
|
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
||||||
|
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||||
|
|
||||||
|
<!-- ## ✨ ATRI [ベータテスト]
|
||||||
|
|
||||||
|
この機能はプラグインとしてロードされます。プラグインリポジトリのアドレス:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||||
|
|
||||||
|
1. 《ATRI ~ My Dear Moments》の主人公 ATRI のキャラクターセリフを微調整データセットとして使用した `Qwen1.5-7B-Chat Lora` 微調整モデル。
|
||||||
|
2. 長期記憶
|
||||||
|
3. ミームの理解と返信
|
||||||
|
4. TTS
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
@@ -1,2 +1,3 @@
|
|||||||
from .core.log import LogManager
|
from .core.log import LogManager
|
||||||
logger = LogManager.GetLogger(log_name='astrbot')
|
|
||||||
|
logger = LogManager.GetLogger(log_name="astrbot")
|
||||||
|
|||||||
@@ -4,10 +4,4 @@ from astrbot.core import html_renderer
|
|||||||
from astrbot.core import sp
|
from astrbot.core import sp
|
||||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]
|
||||||
"AstrBotConfig",
|
|
||||||
"logger",
|
|
||||||
"html_renderer",
|
|
||||||
"llm_tool",
|
|
||||||
"sp"
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core import html_renderer
|
from astrbot.core import html_renderer
|
||||||
@@ -6,7 +5,10 @@ from astrbot.core.star.register import register_llm_tool as llm_tool
|
|||||||
|
|
||||||
# event
|
# event
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageEventResult, MessageChain, CommandResult, EventResultType
|
MessageEventResult,
|
||||||
|
MessageChain,
|
||||||
|
CommandResult,
|
||||||
|
EventResultType,
|
||||||
)
|
)
|
||||||
from astrbot.core.platform import AstrMessageEvent
|
from astrbot.core.platform import AstrMessageEvent
|
||||||
|
|
||||||
@@ -18,10 +20,16 @@ from astrbot.core.star.register import (
|
|||||||
register_regex as regex,
|
register_regex as regex,
|
||||||
register_platform_adapter_type as platform_adapter_type,
|
register_platform_adapter_type as platform_adapter_type,
|
||||||
)
|
)
|
||||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
from astrbot.core.star.filter.event_message_type import (
|
||||||
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
EventMessageTypeFilter,
|
||||||
|
EventMessageType,
|
||||||
|
)
|
||||||
|
from astrbot.core.star.filter.platform_adapter_type import (
|
||||||
|
PlatformAdapterTypeFilter,
|
||||||
|
PlatformAdapterType,
|
||||||
|
)
|
||||||
from astrbot.core.star.register import (
|
from astrbot.core.star.register import (
|
||||||
register_star as register # 注册插件(Star)
|
register_star as register, # 注册插件(Star)
|
||||||
)
|
)
|
||||||
from astrbot.core.star import Context, Star
|
from astrbot.core.star import Context, Star
|
||||||
from astrbot.core.star.config import *
|
from astrbot.core.star.config import *
|
||||||
@@ -32,7 +40,12 @@ from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
|||||||
|
|
||||||
# platform
|
# platform
|
||||||
from astrbot.core.platform import (
|
from astrbot.core.platform import (
|
||||||
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
AstrMessageEvent,
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.platform.register import register_platform_adapter
|
from astrbot.core.platform.register import register_platform_adapter
|
||||||
|
|||||||
@@ -5,31 +5,45 @@ from astrbot.core.star.register import (
|
|||||||
register_regex as regex,
|
register_regex as regex,
|
||||||
register_platform_adapter_type as platform_adapter_type,
|
register_platform_adapter_type as platform_adapter_type,
|
||||||
register_permission_type as permission_type,
|
register_permission_type as permission_type,
|
||||||
|
register_custom_filter as custom_filter,
|
||||||
|
register_on_astrbot_loaded as on_astrbot_loaded,
|
||||||
register_on_llm_request as on_llm_request,
|
register_on_llm_request as on_llm_request,
|
||||||
|
register_on_llm_response as on_llm_response,
|
||||||
register_llm_tool as llm_tool,
|
register_llm_tool as llm_tool,
|
||||||
register_on_decorating_result as on_decorating_result,
|
register_on_decorating_result as on_decorating_result,
|
||||||
register_after_message_sent as after_message_sent
|
register_after_message_sent as after_message_sent,
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
from astrbot.core.star.filter.event_message_type import (
|
||||||
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
EventMessageTypeFilter,
|
||||||
|
EventMessageType,
|
||||||
|
)
|
||||||
|
from astrbot.core.star.filter.platform_adapter_type import (
|
||||||
|
PlatformAdapterTypeFilter,
|
||||||
|
PlatformAdapterType,
|
||||||
|
)
|
||||||
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
|
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
|
||||||
|
from astrbot.core.star.filter.custom_filter import CustomFilter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'command',
|
"command",
|
||||||
'command_group',
|
"command_group",
|
||||||
'event_message_type',
|
"event_message_type",
|
||||||
'regex',
|
"regex",
|
||||||
'platform_adapter_type',
|
"platform_adapter_type",
|
||||||
'permission_type',
|
"permission_type",
|
||||||
'EventMessageTypeFilter',
|
"EventMessageTypeFilter",
|
||||||
'EventMessageType',
|
"EventMessageType",
|
||||||
'PlatformAdapterTypeFilter',
|
"PlatformAdapterTypeFilter",
|
||||||
'PlatformAdapterType',
|
"PlatformAdapterType",
|
||||||
'PermissionTypeFilter',
|
"PermissionTypeFilter",
|
||||||
'PermissionType',
|
"CustomFilter",
|
||||||
'on_llm_request',
|
"custom_filter",
|
||||||
'llm_tool',
|
"PermissionType",
|
||||||
'on_decorating_result',
|
"on_astrbot_loaded",
|
||||||
'after_message_sent'
|
"on_llm_request",
|
||||||
|
"llm_tool",
|
||||||
|
"on_decorating_result",
|
||||||
|
"after_message_sent",
|
||||||
|
"on_llm_response",
|
||||||
]
|
]
|
||||||
@@ -1,5 +1,23 @@
|
|||||||
from astrbot.core.platform import (
|
from astrbot.core.platform import (
|
||||||
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
AstrMessageEvent,
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
Group,
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.platform.register import register_platform_adapter
|
from astrbot.core.platform.register import register_platform_adapter
|
||||||
|
from astrbot.core.message.components import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AstrMessageEvent",
|
||||||
|
"Platform",
|
||||||
|
"AstrBotMessage",
|
||||||
|
"MessageMember",
|
||||||
|
"MessageType",
|
||||||
|
"PlatformMetadata",
|
||||||
|
"register_platform_adapter",
|
||||||
|
"Group",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,2 +1,17 @@
|
|||||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
|
from astrbot.core.provider.entities import (
|
||||||
|
ProviderRequest,
|
||||||
|
ProviderType,
|
||||||
|
ProviderMetaData,
|
||||||
|
LLMResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Provider",
|
||||||
|
"STTProvider",
|
||||||
|
"Personality",
|
||||||
|
"ProviderRequest",
|
||||||
|
"ProviderType",
|
||||||
|
"ProviderMetaData",
|
||||||
|
"LLMResponse",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from astrbot.core.star.register import (
|
from astrbot.core.star.register import (
|
||||||
register_star as register # 注册插件(Star)
|
register_star as register, # 注册插件(Star)
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.star import Context, Star
|
from astrbot.core.star import Context, Star, StarTools
|
||||||
from astrbot.core.star.config import *
|
from astrbot.core.star.config import *
|
||||||
|
|
||||||
|
__all__ = ["register", "Context", "Star", "StarTools"]
|
||||||
|
|||||||
7
astrbot/api/util/__init__.py
Normal file
7
astrbot/api/util/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from astrbot.core.utils.session_waiter import (
|
||||||
|
SessionWaiter,
|
||||||
|
SessionController,
|
||||||
|
session_waiter,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from .log import LogManager, LogBroker
|
from .log import LogManager, LogBroker # noqa
|
||||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||||
from astrbot.core.utils.pip_installer import PipInstaller
|
from astrbot.core.utils.pip_installer import PipInstaller
|
||||||
@@ -8,18 +8,23 @@ from astrbot.core.db.sqlite import SQLiteDatabase
|
|||||||
from astrbot.core.config.default import DB_PATH
|
from astrbot.core.config.default import DB_PATH
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
|
|
||||||
|
# 初始化数据存储文件夹
|
||||||
os.makedirs("data", exist_ok=True)
|
os.makedirs("data", exist_ok=True)
|
||||||
|
|
||||||
astrbot_config = AstrBotConfig()
|
astrbot_config = AstrBotConfig()
|
||||||
html_renderer = HtmlRenderer()
|
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||||
logger = LogManager.GetLogger(log_name='astrbot')
|
html_renderer = HtmlRenderer(t2i_base_url)
|
||||||
|
logger = LogManager.GetLogger(log_name="astrbot")
|
||||||
|
|
||||||
if os.environ.get('TESTING', ""):
|
if os.environ.get("TESTING", ""):
|
||||||
logger.setLevel('DEBUG')
|
logger.setLevel("DEBUG")
|
||||||
|
|
||||||
db_helper = SQLiteDatabase(DB_PATH)
|
db_helper = SQLiteDatabase(DB_PATH)
|
||||||
sp = SharedPreferences() # 简单的偏好设置存储
|
sp = (
|
||||||
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
|
SharedPreferences()
|
||||||
|
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||||
|
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
||||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||||
|
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||||
|
|||||||
@@ -1,2 +1,9 @@
|
|||||||
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
|
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
|
||||||
from .astrbot_config import *
|
from .astrbot_config import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_CONFIG",
|
||||||
|
"VERSION",
|
||||||
|
"DB_PATH",
|
||||||
|
"AstrBotConfig",
|
||||||
|
]
|
||||||
|
|||||||
@@ -2,42 +2,88 @@ import os
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import enum
|
import enum
|
||||||
from .default import DEFAULT_CONFIG
|
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||||
logger = logging.getLogger("astrbot")
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
|
|
||||||
class RateLimitStrategy(enum.Enum):
|
class RateLimitStrategy(enum.Enum):
|
||||||
STALL = "stall"
|
STALL = "stall"
|
||||||
DISCARD = "discard"
|
DISCARD = "discard"
|
||||||
|
|
||||||
|
|
||||||
class AstrBotConfig(dict):
|
class AstrBotConfig(dict):
|
||||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
|
"""从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
|
||||||
|
|
||||||
def __init__(self):
|
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
|
||||||
|
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
|
||||||
|
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||||
|
default_config: dict = DEFAULT_CONFIG,
|
||||||
|
schema: dict = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not self.check_exist():
|
|
||||||
'''不存在时载入默认配置'''
|
|
||||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
|
||||||
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
|
|
||||||
|
|
||||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
|
||||||
|
object.__setattr__(self, "config_path", config_path)
|
||||||
|
object.__setattr__(self, "default_config", default_config)
|
||||||
|
object.__setattr__(self, "schema", schema)
|
||||||
|
|
||||||
|
if schema:
|
||||||
|
default_config = self._config_schema_to_default_config(schema)
|
||||||
|
|
||||||
|
if not self.check_exist():
|
||||||
|
"""不存在时载入默认配置"""
|
||||||
|
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||||
|
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||||
conf_str = f.read()
|
conf_str = f.read()
|
||||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
if conf_str.startswith("/ufeff"): # remove BOM
|
||||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
|
||||||
conf = json.loads(conf_str)
|
conf = json.loads(conf_str)
|
||||||
|
|
||||||
# 检查配置完整性,并插入
|
# 检查配置完整性,并插入
|
||||||
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
|
has_new = self.check_config_integrity(default_config, conf)
|
||||||
self.update(conf)
|
self.update(conf)
|
||||||
if has_new:
|
if has_new:
|
||||||
self.save_config()
|
self.save_config()
|
||||||
|
|
||||||
self.update(conf)
|
self.update(conf)
|
||||||
|
|
||||||
|
def _config_schema_to_default_config(self, schema: dict) -> dict:
|
||||||
|
"""将 Schema 转换成 Config"""
|
||||||
|
conf = {}
|
||||||
|
|
||||||
|
def _parse_schema(schema: dict, conf: dict):
|
||||||
|
for k, v in schema.items():
|
||||||
|
if v["type"] not in DEFAULT_VALUE_MAP:
|
||||||
|
raise TypeError(
|
||||||
|
f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
|
||||||
|
)
|
||||||
|
if "default" in v:
|
||||||
|
default = v["default"]
|
||||||
|
else:
|
||||||
|
default = DEFAULT_VALUE_MAP[v["type"]]
|
||||||
|
|
||||||
|
if v["type"] == "object":
|
||||||
|
conf[k] = {}
|
||||||
|
_parse_schema(v["items"], conf[k])
|
||||||
|
else:
|
||||||
|
conf[k] = default
|
||||||
|
|
||||||
|
_parse_schema(schema, conf)
|
||||||
|
|
||||||
|
return conf
|
||||||
|
|
||||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||||
'''检查配置完整性,如果有新的配置项则返回 True'''
|
"""检查配置完整性,如果有新的配置项则返回 True"""
|
||||||
has_new = False
|
has_new = False
|
||||||
for key, value in refer_conf.items():
|
for key, value in refer_conf.items():
|
||||||
if key not in conf:
|
if key not in conf:
|
||||||
@@ -51,17 +97,19 @@ class AstrBotConfig(dict):
|
|||||||
conf[key] = value
|
conf[key] = value
|
||||||
has_new = True
|
has_new = True
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key)
|
has_new |= self.check_config_integrity(
|
||||||
|
value, conf[key], path + "." + key if path else key
|
||||||
|
)
|
||||||
return has_new
|
return has_new
|
||||||
|
|
||||||
def save_config(self, replace_config: Dict = None):
|
def save_config(self, replace_config: Dict = None):
|
||||||
'''将配置写入文件
|
"""将配置写入文件
|
||||||
|
|
||||||
如果传入 replace_config,则将配置替换为 replace_config
|
如果传入 replace_config,则将配置替换为 replace_config
|
||||||
'''
|
"""
|
||||||
if replace_config:
|
if replace_config:
|
||||||
self.update(replace_config)
|
self.update(replace_config)
|
||||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
with open(self.config_path, "w", encoding="utf-8-sig") as f:
|
||||||
json.dump(self, f, indent=2, ensure_ascii=False)
|
json.dump(self, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
@@ -81,4 +129,4 @@ class AstrBotConfig(dict):
|
|||||||
self[key] = value
|
self[key] = value
|
||||||
|
|
||||||
def check_exist(self) -> bool:
|
def check_exist(self) -> bool:
|
||||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
return os.path.exists(self.config_path)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
VERSION = "3.4.12"
|
VERSION = "3.5.3.1"
|
||||||
DB_PATH = "data/data_v3.db"
|
DB_PATH = "data/data_v3.db"
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
@@ -16,7 +16,7 @@ DEFAULT_CONFIG = {
|
|||||||
"strategy": "stall", # stall, discard
|
"strategy": "stall", # stall, discard
|
||||||
},
|
},
|
||||||
"reply_prefix": "",
|
"reply_prefix": "",
|
||||||
"forward_threshold": 200,
|
"forward_threshold": 1500,
|
||||||
"enable_id_white_list": True,
|
"enable_id_white_list": True,
|
||||||
"id_whitelist": [],
|
"id_whitelist": [],
|
||||||
"id_whitelist_log": True,
|
"id_whitelist_log": True,
|
||||||
@@ -24,17 +24,34 @@ DEFAULT_CONFIG = {
|
|||||||
"wl_ignore_admin_on_friend": True,
|
"wl_ignore_admin_on_friend": True,
|
||||||
"reply_with_mention": False,
|
"reply_with_mention": False,
|
||||||
"reply_with_quote": False,
|
"reply_with_quote": False,
|
||||||
"path_mapping": []
|
"path_mapping": [],
|
||||||
|
"segmented_reply": {
|
||||||
|
"enable": False,
|
||||||
|
"only_llm_result": True,
|
||||||
|
"interval_method": "random",
|
||||||
|
"interval": "1.5,3.5",
|
||||||
|
"log_base": 2.6,
|
||||||
|
"words_count_threshold": 150,
|
||||||
|
"regex": ".*?[。?!~…]+|.+$",
|
||||||
|
"content_cleanup_rule": "",
|
||||||
|
},
|
||||||
|
"no_permission_reply": True,
|
||||||
|
"empty_mention_waiting": True,
|
||||||
|
"friend_message_needs_wake_prefix": False,
|
||||||
},
|
},
|
||||||
"provider": [],
|
"provider": [],
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"wake_prefix": "",
|
"wake_prefix": "",
|
||||||
"web_search": False,
|
"web_search": False,
|
||||||
|
"web_search_link": False,
|
||||||
"identifier": False,
|
"identifier": False,
|
||||||
"datetime_system_prompt": True,
|
"datetime_system_prompt": True,
|
||||||
"default_personality": "default",
|
"default_personality": "default",
|
||||||
"prompt_prefix": "",
|
"prompt_prefix": "",
|
||||||
|
"max_context_length": -1,
|
||||||
|
"dequeue_context_length": 1,
|
||||||
|
"streaming_response": False,
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -44,33 +61,46 @@ DEFAULT_CONFIG = {
|
|||||||
"enable": False,
|
"enable": False,
|
||||||
"provider_id": "",
|
"provider_id": "",
|
||||||
},
|
},
|
||||||
|
"provider_ltm_settings": {
|
||||||
|
"group_icl_enable": False,
|
||||||
|
"group_message_max_cnt": 300,
|
||||||
|
"image_caption": False,
|
||||||
|
"image_caption_provider_id": "",
|
||||||
|
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||||
|
"active_reply": {
|
||||||
|
"enable": False,
|
||||||
|
"method": "possibility_reply",
|
||||||
|
"possibility_reply": 0.1,
|
||||||
|
"prompt": "",
|
||||||
|
"whitelist": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
"content_safety": {
|
"content_safety": {
|
||||||
|
"also_use_in_response": False,
|
||||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||||
},
|
},
|
||||||
"admins_id": [],
|
"admins_id": ["astrbot"],
|
||||||
"t2i": False,
|
"t2i": False,
|
||||||
|
"t2i_word_threshold": 150,
|
||||||
|
"t2i_strategy": "remote",
|
||||||
|
"t2i_endpoint": "",
|
||||||
"http_proxy": "",
|
"http_proxy": "",
|
||||||
"dashboard": {
|
"dashboard": {
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"username": "astrbot",
|
"username": "astrbot",
|
||||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": 6185,
|
||||||
},
|
},
|
||||||
"platform": [],
|
"platform": [],
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
"t2i_endpoint": "",
|
|
||||||
"pip_install_arg": "",
|
"pip_install_arg": "",
|
||||||
"plugin_repo_mirror": "",
|
"plugin_repo_mirror": "",
|
||||||
"knowledge_db": {},
|
"knowledge_db": {},
|
||||||
"persona": [
|
"persona": [],
|
||||||
{
|
"timezone": "",
|
||||||
"name": "default",
|
|
||||||
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
|
||||||
"begin_dialogs": [],
|
|
||||||
"mood_imitation_dialogs": [],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -92,29 +122,80 @@ CONFIG_METADATA_2 = {
|
|||||||
"enable_group_c2c": True,
|
"enable_group_c2c": True,
|
||||||
"enable_guild_direct_message": True,
|
"enable_guild_direct_message": True,
|
||||||
},
|
},
|
||||||
"aiocqhtp(QQ)": {
|
"qq_official_webhook(QQ)": {
|
||||||
|
"id": "default",
|
||||||
|
"type": "qq_official_webhook",
|
||||||
|
"enable": False,
|
||||||
|
"appid": "",
|
||||||
|
"secret": "",
|
||||||
|
"callback_server_host": "0.0.0.0",
|
||||||
|
"port": 6196,
|
||||||
|
},
|
||||||
|
"aiocqhttp(OneBotv11)": {
|
||||||
"id": "default",
|
"id": "default",
|
||||||
"type": "aiocqhttp",
|
"type": "aiocqhttp",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"ws_reverse_host": "",
|
"ws_reverse_host": "0.0.0.0",
|
||||||
"ws_reverse_port": 6199,
|
"ws_reverse_port": 6199,
|
||||||
},
|
},
|
||||||
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
|
|
||||||
"gewechat(微信)": {
|
"gewechat(微信)": {
|
||||||
"id": "gwchat",
|
"id": "gwchat",
|
||||||
"type": "gewechat",
|
"type": "gewechat",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"base_url": "http://localhost:2531",
|
"base_url": "http://localhost:2531",
|
||||||
"nickname": "soulter",
|
"nickname": "soulter",
|
||||||
"host": "localhost",
|
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||||
"port": 11451,
|
"port": 11451,
|
||||||
},
|
},
|
||||||
|
"wecom(企业微信)": {
|
||||||
|
"id": "wecom",
|
||||||
|
"type": "wecom",
|
||||||
|
"enable": False,
|
||||||
|
"corpid": "",
|
||||||
|
"secret": "",
|
||||||
|
"token": "",
|
||||||
|
"encoding_aes_key": "",
|
||||||
|
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||||
|
"callback_server_host": "0.0.0.0",
|
||||||
|
"port": 6195,
|
||||||
|
},
|
||||||
|
"lark(飞书)": {
|
||||||
|
"id": "lark",
|
||||||
|
"type": "lark",
|
||||||
|
"enable": False,
|
||||||
|
"lark_bot_name": "",
|
||||||
|
"app_id": "",
|
||||||
|
"app_secret": "",
|
||||||
|
"domain": "https://open.feishu.cn",
|
||||||
|
},
|
||||||
|
"dingtalk(钉钉)": {
|
||||||
|
"id": "dingtalk",
|
||||||
|
"type": "dingtalk",
|
||||||
|
"enable": False,
|
||||||
|
"client_id": "",
|
||||||
|
"client_secret": "",
|
||||||
|
},
|
||||||
|
"telegram": {
|
||||||
|
"id": "telegram",
|
||||||
|
"type": "telegram",
|
||||||
|
"enable": False,
|
||||||
|
"telegram_token": "your_bot_token",
|
||||||
|
"start_message": "Hello, I'm AstrBot!",
|
||||||
|
"telegram_api_base_url": "https://api.telegram.org/bot",
|
||||||
|
"telegram_file_base_url": "https://api.telegram.org/file/bot",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
|
"telegram_token": {
|
||||||
|
"description": "Bot Token",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||||
|
},
|
||||||
"id": {
|
"id": {
|
||||||
"description": "ID",
|
"description": "ID",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。",
|
"obvious_hint": True,
|
||||||
|
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"description": "适配器类型",
|
"description": "适配器类型",
|
||||||
@@ -147,7 +228,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"hint": "启用后,机器人可以接收到频道的私聊消息。",
|
"hint": "启用后,机器人可以接收到频道的私聊消息。",
|
||||||
},
|
},
|
||||||
"ws_reverse_host": {
|
"ws_reverse_host": {
|
||||||
"description": "反向 Websocket 主机地址",
|
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
|
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
|
||||||
},
|
},
|
||||||
@@ -156,12 +237,21 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "int",
|
"type": "int",
|
||||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||||
},
|
},
|
||||||
|
"lark_bot_name": {
|
||||||
|
"description": "飞书机器人的名字",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||||
|
"obvious_hint": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"platform_settings": {
|
"platform_settings": {
|
||||||
"description": "平台设置",
|
"description": "平台设置",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
|
"plugin_enable": {
|
||||||
|
"invisible": True, # 隐藏插件启用配置
|
||||||
|
},
|
||||||
"unique_session": {
|
"unique_session": {
|
||||||
"description": "会话隔离",
|
"description": "会话隔离",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -182,6 +272,68 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"no_permission_reply": {
|
||||||
|
"description": "无权限回复",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
|
||||||
|
},
|
||||||
|
"empty_mention_waiting": {
|
||||||
|
"description": "只 @ 机器人是否触发等待回复",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待回复,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
|
||||||
|
},
|
||||||
|
"friend_message_needs_wake_prefix": {
|
||||||
|
"description": "私聊消息是否需要唤醒前缀",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,私聊消息需要唤醒前缀才会被处理,同群聊一样。",
|
||||||
|
},
|
||||||
|
"segmented_reply": {
|
||||||
|
"description": "分段回复",
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"enable": {
|
||||||
|
"description": "启用分段回复",
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
|
"only_llm_result": {
|
||||||
|
"description": "仅对 LLM 结果分段",
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
|
"interval_method": {
|
||||||
|
"description": "间隔时间计算方法",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["random", "log"],
|
||||||
|
"hint": "分段回复的间隔时间计算方法。random 为随机时间,log 为根据消息长度计算,$y=log_<log_base>(x)$,x为字数,y的单位为秒。",
|
||||||
|
},
|
||||||
|
"interval": {
|
||||||
|
"description": "随机间隔时间(秒)",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "`random` 方法用。每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
|
||||||
|
},
|
||||||
|
"log_base": {
|
||||||
|
"description": "对数函数底数",
|
||||||
|
"type": "float",
|
||||||
|
"hint": "`log` 方法用。对数函数的底数。默认为 2.6",
|
||||||
|
},
|
||||||
|
"words_count_threshold": {
|
||||||
|
"description": "字数阈值",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超过这个字数的消息不会被分段回复。默认为 150",
|
||||||
|
},
|
||||||
|
"regex": {
|
||||||
|
"description": "正则表达式",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
|
||||||
|
},
|
||||||
|
"content_cleanup_rule": {
|
||||||
|
"description": "过滤分段后的内容",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"reply_prefix": {
|
"reply_prefix": {
|
||||||
"description": "回复前缀",
|
"description": "回复前缀",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -199,8 +351,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"id_whitelist": {
|
"id_whitelist": {
|
||||||
"description": "ID 白名单",
|
"description": "ID 白名单",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "int"},
|
"items": {"type": "string"},
|
||||||
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
"obvious_hint": True,
|
||||||
|
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
|
||||||
},
|
},
|
||||||
"id_whitelist_log": {
|
"id_whitelist_log": {
|
||||||
"description": "打印白名单日志",
|
"description": "打印白名单日志",
|
||||||
@@ -228,15 +381,21 @@ CONFIG_METADATA_2 = {
|
|||||||
"path_mapping": {
|
"path_mapping": {
|
||||||
"description": "路径映射",
|
"description": "路径映射",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
"obvious_hint": True,
|
"obvious_hint": True,
|
||||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"content_safety": {
|
"content_safety": {
|
||||||
"description": "内容安全",
|
"description": "内容安全",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
|
"also_use_in_response": {
|
||||||
|
"description": "对大模型响应安全审核",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,大模型的响应也会通过内容安全审核。",
|
||||||
|
},
|
||||||
"baidu_aip": {
|
"baidu_aip": {
|
||||||
"description": "百度内容审核配置",
|
"description": "百度内容审核配置",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -281,17 +440,53 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "服务提供商配置",
|
"description": "服务提供商配置",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"config_template": {
|
"config_template": {
|
||||||
"openai": {
|
"OpenAI": {
|
||||||
"id": "default",
|
"id": "openai",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "",
|
"api_base": "https://api.openai.com/v1",
|
||||||
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ollama": {
|
"Azure_OpenAI": {
|
||||||
|
"id": "azure",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"api_version": "2024-05-01-preview",
|
||||||
|
"key": [],
|
||||||
|
"api_base": "",
|
||||||
|
"timeout": 120,
|
||||||
|
"model_config": {
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"xAI(grok)": {
|
||||||
|
"id": "xai",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"api_base": "https://api.x.ai/v1",
|
||||||
|
"timeout": 120,
|
||||||
|
"model_config": {
|
||||||
|
"model": "grok-2-latest",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Anthropic(claude)": {
|
||||||
|
"id": "claude",
|
||||||
|
"type": "anthropic_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"api_base": "https://api.anthropic.com/v1",
|
||||||
|
"timeout": 120,
|
||||||
|
"model_config": {
|
||||||
|
"model": "claude-3-5-sonnet-latest",
|
||||||
|
"max_tokens": 4096,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Ollama": {
|
||||||
"id": "ollama_default",
|
"id": "ollama_default",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -301,47 +496,90 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "llama3.1-8b",
|
"model": "llama3.1-8b",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"gemini(OpenAI兼容)": {
|
"LM_Studio": {
|
||||||
|
"id": "lm_studio",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": ["lmstudio"],
|
||||||
|
"api_base": "http://localhost:1234/v1",
|
||||||
|
"model_config": {
|
||||||
|
"model": "llama-3.1-8b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Gemini(OpenAI兼容)": {
|
||||||
"id": "gemini_default",
|
"id": "gemini_default",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||||
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-1.5-flash",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"gemini(googlegenai原生)": {
|
"Gemini(googlegenai原生)": {
|
||||||
"id": "gemini_default",
|
"id": "gemini_default",
|
||||||
"type": "googlegenai_chat_completion",
|
"type": "googlegenai_chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://generativelanguage.googleapis.com/",
|
"api_base": "https://generativelanguage.googleapis.com/",
|
||||||
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-2.0-flash-exp",
|
||||||
|
},
|
||||||
|
"gm_resp_image_modal": False,
|
||||||
|
"gm_safety_settings": {
|
||||||
|
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"deepseek": {
|
"DeepSeek": {
|
||||||
"id": "deepseek_default",
|
"id": "deepseek_default",
|
||||||
"type": "openai_chat_completion",
|
"type": "openai_chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://api.deepseek.com/v1",
|
"api_base": "https://api.deepseek.com/v1",
|
||||||
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "deepseek-chat",
|
"model": "deepseek-chat",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"zhipu": {
|
"Zhipu(智谱)": {
|
||||||
"id": "zhipu_default",
|
"id": "zhipu_default",
|
||||||
"type": "zhipu_chat_completion",
|
"type": "zhipu_chat_completion",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"key": [],
|
"key": [],
|
||||||
|
"timeout": 120,
|
||||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "glm-4-flash",
|
"model": "glm-4-flash",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"llmtuner": {
|
"SiliconFlow(硅基流动)": {
|
||||||
|
"id": "siliconflow",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"timeout": 120,
|
||||||
|
"api_base": "https://api.siliconflow.cn/v1",
|
||||||
|
"model_config": {
|
||||||
|
"model": "deepseek-ai/DeepSeek-V3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"MoonShot(Kimi)": {
|
||||||
|
"id": "moonshot",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"timeout": 120,
|
||||||
|
"api_base": "https://api.moonshot.cn/v1",
|
||||||
|
"model_config": {
|
||||||
|
"model": "moonshot-v1-8k",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"LLMTuner": {
|
||||||
"id": "llmtuner_default",
|
"id": "llmtuner_default",
|
||||||
"type": "llm_tuner",
|
"type": "llm_tuner",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
@@ -351,16 +589,42 @@ CONFIG_METADATA_2 = {
|
|||||||
"finetuning_type": "lora",
|
"finetuning_type": "lora",
|
||||||
"quantization_bit": 4,
|
"quantization_bit": 4,
|
||||||
},
|
},
|
||||||
"dify": {
|
"Dify": {
|
||||||
"id": "dify_app_default",
|
"id": "dify_app_default",
|
||||||
"type": "dify",
|
"type": "dify",
|
||||||
"enable": True,
|
"enable": True,
|
||||||
"dify_api_type": "chat",
|
"dify_api_type": "chat",
|
||||||
"dify_api_key": "",
|
"dify_api_key": "",
|
||||||
"dify_api_base": "https://api.dify.ai/v1",
|
"dify_api_base": "https://api.dify.ai/v1",
|
||||||
"dify_workflow_output_key": "",
|
"dify_workflow_output_key": "astrbot_wf_output",
|
||||||
|
"dify_query_input_key": "astrbot_text_query",
|
||||||
|
"variables": {},
|
||||||
|
"timeout": 60,
|
||||||
},
|
},
|
||||||
"whisper(API)": {
|
"Dashscope(阿里云百炼应用)": {
|
||||||
|
"id": "dashscope",
|
||||||
|
"type": "dashscope",
|
||||||
|
"enable": True,
|
||||||
|
"dashscope_app_type": "agent",
|
||||||
|
"dashscope_api_key": "",
|
||||||
|
"dashscope_app_id": "",
|
||||||
|
"rag_options": {
|
||||||
|
"pipeline_ids": [],
|
||||||
|
"file_ids": [],
|
||||||
|
"output_reference": False,
|
||||||
|
},
|
||||||
|
"variables": {},
|
||||||
|
"timeout": 60,
|
||||||
|
},
|
||||||
|
"FastGPT": {
|
||||||
|
"id": "fastgpt",
|
||||||
|
"type": "openai_chat_completion",
|
||||||
|
"enable": True,
|
||||||
|
"key": [],
|
||||||
|
"api_base": "https://api.fastgpt.in/api/v1",
|
||||||
|
"timeout": 60,
|
||||||
|
},
|
||||||
|
"Whisper(API)": {
|
||||||
"id": "whisper",
|
"id": "whisper",
|
||||||
"type": "openai_whisper_api",
|
"type": "openai_whisper_api",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -368,23 +632,213 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "",
|
"api_base": "",
|
||||||
"model": "whisper-1",
|
"model": "whisper-1",
|
||||||
},
|
},
|
||||||
"whisper(本地加载)": {
|
"Whisper(本地加载)": {
|
||||||
"whisper_hint": "(不用修改我)",
|
"whisper_hint": "(不用修改我)",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"id": "whisper",
|
"id": "whisper",
|
||||||
"type": "openai_whisper_selfhost",
|
"type": "openai_whisper_selfhost",
|
||||||
"model": "tiny",
|
"model": "tiny",
|
||||||
},
|
},
|
||||||
"openai_tts(API)": {
|
"sensevoice(本地加载)": {
|
||||||
|
"sensevoice_hint": "(不用修改我)",
|
||||||
|
"enable": False,
|
||||||
|
"id": "sensevoice",
|
||||||
|
"type": "sensevoice_stt_selfhost",
|
||||||
|
"stt_model": "iic/SenseVoiceSmall",
|
||||||
|
"is_emotion": False,
|
||||||
|
},
|
||||||
|
"OpenAI_TTS(API)": {
|
||||||
"id": "openai_tts",
|
"id": "openai_tts",
|
||||||
"type": "openai_tts_api",
|
"type": "openai_tts_api",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"api_base": "",
|
"api_base": "",
|
||||||
"model": "tts-1",
|
"model": "tts-1",
|
||||||
|
"openai-tts-voice": "alloy",
|
||||||
|
"timeout": "20",
|
||||||
|
},
|
||||||
|
"Edge_TTS": {
|
||||||
|
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||||
|
"id": "edge_tts",
|
||||||
|
"type": "edge_tts",
|
||||||
|
"enable": False,
|
||||||
|
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
|
||||||
|
"timeout": 20,
|
||||||
|
},
|
||||||
|
"GSVI_TTS(API)": {
|
||||||
|
"id": "gsvi_tts",
|
||||||
|
"type": "gsvi_tts_api",
|
||||||
|
"api_base": "http://127.0.0.1:5000",
|
||||||
|
"character": "",
|
||||||
|
"emotion": "default",
|
||||||
|
"enable": False,
|
||||||
|
"timeout": 20,
|
||||||
|
},
|
||||||
|
"FishAudio_TTS(API)": {
|
||||||
|
"id": "fishaudio_tts",
|
||||||
|
"type": "fishaudio_tts_api",
|
||||||
|
"enable": False,
|
||||||
|
"api_key": "",
|
||||||
|
"api_base": "https://api.fish.audio/v1",
|
||||||
|
"fishaudio-tts-character": "可莉",
|
||||||
|
"timeout": "20",
|
||||||
|
},
|
||||||
|
"阿里云百炼_TTS(API)": {
|
||||||
|
"id": "dashscope_tts",
|
||||||
|
"type": "dashscope_tts",
|
||||||
|
"enable": False,
|
||||||
|
"api_key": "",
|
||||||
|
"model": "cosyvoice-v1",
|
||||||
|
"dashscope_tts_voice": "loongstella",
|
||||||
|
"timeout": "20",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
|
"dashscope_tts_voice": {
|
||||||
|
"description": "语音合成模型",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
|
||||||
|
},
|
||||||
|
"gm_resp_image_modal": {
|
||||||
|
"description": "启用图片模态",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
|
||||||
|
},
|
||||||
|
"gm_safety_settings": {
|
||||||
|
"description": "安全过滤器",
|
||||||
|
"type": "object",
|
||||||
|
"hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。",
|
||||||
|
"items": {
|
||||||
|
"harassment": {
|
||||||
|
"description": "骚扰内容",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "负面或有害评论",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"hate_speech": {
|
||||||
|
"description": "仇恨言论",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "粗鲁、无礼或亵渎性质内容",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"sexually_explicit": {
|
||||||
|
"description": "露骨色情内容",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "包含性行为或其他淫秽内容的引用",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"dangerous_content": {
|
||||||
|
"description": "危险内容",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "宣扬、助长或鼓励有害行为的信息",
|
||||||
|
"options": [
|
||||||
|
"BLOCK_NONE",
|
||||||
|
"BLOCK_ONLY_HIGH",
|
||||||
|
"BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"BLOCK_LOW_AND_ABOVE",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"rag_options": {
|
||||||
|
"description": "RAG 选项",
|
||||||
|
"type": "object",
|
||||||
|
"hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)。阿里云百炼应用开启此功能后将无法多轮对话。",
|
||||||
|
"items": {
|
||||||
|
"pipeline_ids": {
|
||||||
|
"description": "知识库 ID 列表",
|
||||||
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。",
|
||||||
|
},
|
||||||
|
"file_ids": {
|
||||||
|
"description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。",
|
||||||
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。",
|
||||||
|
},
|
||||||
|
"output_reference": {
|
||||||
|
"description": "是否输出知识库/文档的引用",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "在每次回答尾部加上引用源。默认为 False。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"sensevoice_hint": {
|
||||||
|
"description": "部署SenseVoice",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||||
|
"obvious_hint": True,
|
||||||
|
},
|
||||||
|
"is_emotion": {
|
||||||
|
"description": "情绪识别",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "是否开启情绪识别。happy|sad|angry|neutral|fearful|disgusted|surprised|unknown",
|
||||||
|
},
|
||||||
|
"stt_model": {
|
||||||
|
"description": "模型名称",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。",
|
||||||
|
},
|
||||||
|
"variables": {
|
||||||
|
"description": "工作流固定输入变量",
|
||||||
|
"type": "object",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"items": {},
|
||||||
|
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
|
||||||
|
"invisible": True,
|
||||||
|
},
|
||||||
|
# "fastgpt_app_type": {
|
||||||
|
# "description": "应用类型",
|
||||||
|
# "type": "string",
|
||||||
|
# "hint": "FastGPT 应用的应用类型。",
|
||||||
|
# "options": ["agent", "workflow", "plugin"],
|
||||||
|
# "obvious_hint": True,
|
||||||
|
# },
|
||||||
|
"dashscope_app_type": {
|
||||||
|
"description": "应用类型",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "百炼应用的应用类型。",
|
||||||
|
"options": [
|
||||||
|
"agent",
|
||||||
|
"agent-arrange",
|
||||||
|
"dialog-workflow",
|
||||||
|
"task-workflow",
|
||||||
|
],
|
||||||
|
"obvious_hint": True,
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"description": "超时时间",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超时时间,单位为秒。",
|
||||||
|
},
|
||||||
|
"openai-tts-voice": {
|
||||||
|
"description": "voice",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||||
|
},
|
||||||
|
"fishaudio-tts-character": {
|
||||||
|
"description": "character",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
||||||
|
},
|
||||||
"whisper_hint": {
|
"whisper_hint": {
|
||||||
"description": "本地部署 Whisper 模型须知",
|
"description": "本地部署 Whisper 模型须知",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -394,7 +848,8 @@ CONFIG_METADATA_2 = {
|
|||||||
"id": {
|
"id": {
|
||||||
"description": "ID",
|
"description": "ID",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。",
|
"obvious_hint": True,
|
||||||
|
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
"description": "模型提供商类型",
|
"description": "模型提供商类型",
|
||||||
@@ -415,7 +870,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": {
|
"api_base": {
|
||||||
"description": "API Base URL",
|
"description": "API Base URL",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。",
|
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||||
"obvious_hint": True,
|
"obvious_hint": True,
|
||||||
},
|
},
|
||||||
"base_model_path": {
|
"base_model_path": {
|
||||||
@@ -473,14 +928,20 @@ CONFIG_METADATA_2 = {
|
|||||||
"dify_api_type": {
|
"dify_api_type": {
|
||||||
"description": "Dify 应用类型",
|
"description": "Dify 应用类型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, chatflow, agent, workflow 三种应用类型。",
|
||||||
"options": ["chat", "agent", "workflow"],
|
"options": ["chat", "chatflow", "agent", "workflow"],
|
||||||
},
|
},
|
||||||
"dify_workflow_output_key": {
|
"dify_workflow_output_key": {
|
||||||
"description": "Dify Workflow 输出变量名",
|
"description": "Dify Workflow 输出变量名",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
||||||
},
|
},
|
||||||
|
"dify_query_input_key": {
|
||||||
|
"description": "Prompt 输入变量名",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||||
|
"obvious": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
@@ -501,16 +962,25 @@ CONFIG_METADATA_2 = {
|
|||||||
"web_search": {
|
"web_search": {
|
||||||
"description": "启用网页搜索",
|
"description": "启用网页搜索",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
"obvious_hint": True,
|
||||||
|
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||||
|
},
|
||||||
|
"web_search_link": {
|
||||||
|
"description": "网页搜索引用链接",
|
||||||
|
"type": "bool",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||||
},
|
},
|
||||||
"identifier": {
|
"identifier": {
|
||||||
"description": "启动识别群员",
|
"description": "启动识别群员",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
|
"obvious_hint": True,
|
||||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||||
},
|
},
|
||||||
"datetime_system_prompt": {
|
"datetime_system_prompt": {
|
||||||
"description": "启用日期时间系统提示",
|
"description": "启用日期时间系统提示",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
|
"obvious_hint": True,
|
||||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||||
},
|
},
|
||||||
"default_personality": {
|
"default_personality": {
|
||||||
@@ -523,6 +993,21 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
|
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
|
||||||
},
|
},
|
||||||
|
"max_context_length": {
|
||||||
|
"description": "最多携带对话数量(条)",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
|
||||||
|
},
|
||||||
|
"dequeue_context_length": {
|
||||||
|
"description": "丢弃对话数量(条)",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超出 最多携带对话数量(条) 时,丢弃多少条记录,用户和AI的一轮聊天记为 1 条。适宜的配置,可以提高超长上下文对话 deepseek 命中缓存效果,理想情况下计费将降低到1/3以下",
|
||||||
|
},
|
||||||
|
"streaming_response": {
|
||||||
|
"description": "启用流式回复",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"persona": {
|
"persona": {
|
||||||
@@ -552,15 +1037,15 @@ CONFIG_METADATA_2 = {
|
|||||||
"begin_dialogs": {
|
"begin_dialogs": {
|
||||||
"description": "预设对话",
|
"description": "预设对话",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {},
|
"items": {"type": "string"},
|
||||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||||
"obvious_hint": True,
|
"obvious_hint": True,
|
||||||
},
|
},
|
||||||
"mood_imitation_dialogs": {
|
"mood_imitation_dialogs": {
|
||||||
"description": "对话风格模仿",
|
"description": "对话风格模仿",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {},
|
"items": {"type": "string"},
|
||||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||||
"obvious_hint": True,
|
"obvious_hint": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -599,6 +1084,77 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"provider_ltm_settings": {
|
||||||
|
"description": "聊天记忆增强(Beta)",
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"group_icl_enable": {
|
||||||
|
"description": "群聊内记录各群员对话",
|
||||||
|
"type": "bool",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||||
|
},
|
||||||
|
"group_message_max_cnt": {
|
||||||
|
"description": "群聊消息最大数量",
|
||||||
|
"type": "int",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||||
|
},
|
||||||
|
"image_caption": {
|
||||||
|
"description": "群聊图像转述(需模型支持)",
|
||||||
|
"type": "bool",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
|
||||||
|
},
|
||||||
|
"image_caption_provider_id": {
|
||||||
|
"description": "图像转述提供商 ID",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
|
||||||
|
},
|
||||||
|
"image_caption_prompt": {
|
||||||
|
"description": "图像转述提示词",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"active_reply": {
|
||||||
|
"description": "主动回复",
|
||||||
|
"type": "object",
|
||||||
|
"items": {
|
||||||
|
"enable": {
|
||||||
|
"description": "启用主动回复",
|
||||||
|
"type": "bool",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
|
||||||
|
},
|
||||||
|
"whitelist": {
|
||||||
|
"description": "主动回复白名单",
|
||||||
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
|
||||||
|
},
|
||||||
|
"method": {
|
||||||
|
"description": "回复方法",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["possibility_reply"],
|
||||||
|
"hint": "回复方法。possibility_reply 为根据概率回复",
|
||||||
|
},
|
||||||
|
"possibility_reply": {
|
||||||
|
"description": "回复概率",
|
||||||
|
"type": "float",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||||
|
},
|
||||||
|
"prompt": {
|
||||||
|
"description": "提示词",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"misc_config_group": {
|
"misc_config_group": {
|
||||||
@@ -608,34 +1164,52 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "机器人唤醒前缀",
|
"description": "机器人唤醒前缀",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
|
"obvious_hint": True,
|
||||||
|
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
|
||||||
},
|
},
|
||||||
"t2i": {
|
"t2i": {
|
||||||
"description": "文本转图像",
|
"description": "文本转图像",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。",
|
"hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。",
|
||||||
},
|
},
|
||||||
|
"t2i_word_threshold": {
|
||||||
|
"description": "文本转图像字数阈值",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超出此字符长度的文本将会被转换成图片。字数不能低于 50。",
|
||||||
|
},
|
||||||
"admins_id": {
|
"admins_id": {
|
||||||
"description": "管理员 ID",
|
"description": "管理员 ID",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "int"},
|
"items": {"type": "string"},
|
||||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
|
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
|
||||||
},
|
},
|
||||||
"http_proxy": {
|
"http_proxy": {
|
||||||
"description": "HTTP 代理",
|
"description": "HTTP 代理",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
||||||
},
|
},
|
||||||
|
"timezone": {
|
||||||
|
"description": "时区",
|
||||||
|
"type": "string",
|
||||||
|
"obvious_hint": True,
|
||||||
|
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||||
|
},
|
||||||
"log_level": {
|
"log_level": {
|
||||||
"description": "控制台日志级别",
|
"description": "控制台日志级别",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "控制台输出日志的级别。",
|
"hint": "控制台输出日志的级别。",
|
||||||
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||||
},
|
},
|
||||||
|
"t2i_strategy": {
|
||||||
|
"description": "文本转图像渲染源",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。",
|
||||||
|
"options": ["remote", "local"],
|
||||||
|
},
|
||||||
"t2i_endpoint": {
|
"t2i_endpoint": {
|
||||||
"description": "文本转图像服务接口",
|
"description": "文本转图像服务接口",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "为空时使用 AstrBot API 服务",
|
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
|
||||||
},
|
},
|
||||||
"pip_install_arg": {
|
"pip_install_arg": {
|
||||||
"description": "pip 安装参数",
|
"description": "pip 安装参数",
|
||||||
@@ -645,7 +1219,8 @@ CONFIG_METADATA_2 = {
|
|||||||
"plugin_repo_mirror": {
|
"plugin_repo_mirror": {
|
||||||
"description": "插件仓库镜像",
|
"description": "插件仓库镜像",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "插件仓库的镜像地址,用于加速插件的下载。",
|
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
|
||||||
|
"obvious_hint": True,
|
||||||
"options": [
|
"options": [
|
||||||
"default",
|
"default",
|
||||||
"https://ghp.ci/",
|
"https://ghp.ci/",
|
||||||
|
|||||||
199
astrbot/core/conversation_mgr.py
Normal file
199
astrbot/core/conversation_mgr.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
|
||||||
|
|
||||||
|
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
|
||||||
|
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from astrbot.core import sp
|
||||||
|
from typing import Dict, List
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core.db.po import Conversation
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationManager:
|
||||||
|
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||||
|
|
||||||
|
def __init__(self, db_helper: BaseDatabase):
|
||||||
|
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||||
|
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||||
|
self.db = db_helper
|
||||||
|
self.save_interval = 60 # 每 60 秒保存一次
|
||||||
|
self._start_periodic_save()
|
||||||
|
|
||||||
|
def _start_periodic_save(self):
|
||||||
|
"""启动定时保存任务"""
|
||||||
|
asyncio.create_task(self._periodic_save())
|
||||||
|
|
||||||
|
async def _periodic_save(self):
|
||||||
|
"""定时保存会话对话映射关系到存储中"""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self.save_interval)
|
||||||
|
self._save_to_storage()
|
||||||
|
|
||||||
|
def _save_to_storage(self):
|
||||||
|
"""保存会话对话映射关系到存储中"""
|
||||||
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
|
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||||
|
"""新建对话,并将当前会话的对话转移到新对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||||
|
self.session_conversations[unified_msg_origin] = conversation_id
|
||||||
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
return conversation_id
|
||||||
|
|
||||||
|
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||||
|
"""切换会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
|
self.session_conversations[unified_msg_origin] = conversation_id
|
||||||
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
|
async def delete_conversation(
|
||||||
|
self, unified_msg_origin: str, conversation_id: str = None
|
||||||
|
):
|
||||||
|
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
|
if conversation_id:
|
||||||
|
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||||
|
del self.session_conversations[unified_msg_origin]
|
||||||
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||||
|
"""获取会话当前的对话 ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
|
return self.session_conversations.get(unified_msg_origin, None)
|
||||||
|
|
||||||
|
async def get_conversation(
|
||||||
|
self, unified_msg_origin: str, conversation_id: str
|
||||||
|
) -> Conversation:
|
||||||
|
"""获取会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
Returns:
|
||||||
|
conversation (Conversation): 对话对象
|
||||||
|
"""
|
||||||
|
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||||
|
|
||||||
|
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||||
|
"""获取会话的所有对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversations (List[Conversation]): 对话对象列表
|
||||||
|
"""
|
||||||
|
return self.db.get_conversations(unified_msg_origin)
|
||||||
|
|
||||||
|
async def update_conversation(
|
||||||
|
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||||
|
):
|
||||||
|
"""更新会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||||
|
"""
|
||||||
|
if conversation_id:
|
||||||
|
self.db.update_conversation(
|
||||||
|
user_id=unified_msg_origin,
|
||||||
|
cid=conversation_id,
|
||||||
|
history=json.dumps(history),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||||
|
"""更新会话的对话标题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
title (str): 对话标题
|
||||||
|
"""
|
||||||
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
|
if conversation_id:
|
||||||
|
self.db.update_conversation_title(
|
||||||
|
user_id=unified_msg_origin, cid=conversation_id, title=title
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_conversation_persona_id(
|
||||||
|
self, unified_msg_origin: str, persona_id: str
|
||||||
|
):
|
||||||
|
"""更新会话的对话 Persona ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
persona_id (str): 对话 Persona ID
|
||||||
|
"""
|
||||||
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
|
if conversation_id:
|
||||||
|
self.db.update_conversation_persona_id(
|
||||||
|
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_human_readable_context(
|
||||||
|
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||||
|
):
|
||||||
|
"""获取人类可读的上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
page (int): 页码
|
||||||
|
page_size (int): 每页大小
|
||||||
|
"""
|
||||||
|
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||||
|
history = json.loads(conversation.history)
|
||||||
|
|
||||||
|
contexts = []
|
||||||
|
temp_contexts = []
|
||||||
|
for record in history:
|
||||||
|
if record["role"] == "user":
|
||||||
|
temp_contexts.append(f"User: {record['content']}")
|
||||||
|
elif record["role"] == "assistant":
|
||||||
|
if "content" in record and record["content"]:
|
||||||
|
temp_contexts.append(f"Assistant: {record['content']}")
|
||||||
|
elif "tool_calls" in record:
|
||||||
|
tool_calls_str = json.dumps(
|
||||||
|
record["tool_calls"], ensure_ascii=False
|
||||||
|
)
|
||||||
|
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
|
||||||
|
else:
|
||||||
|
temp_contexts.append("Assistant: [未知的内容]")
|
||||||
|
contexts.insert(0, temp_contexts)
|
||||||
|
temp_contexts = []
|
||||||
|
|
||||||
|
# 展平 contexts 列表
|
||||||
|
contexts = [item for sublist in contexts for item in sublist]
|
||||||
|
|
||||||
|
# 计算分页
|
||||||
|
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
|
||||||
|
total_pages = len(contexts) // page_size
|
||||||
|
if len(contexts) % page_size != 0:
|
||||||
|
total_pages += 1
|
||||||
|
|
||||||
|
return paged_contexts, total_pages
|
||||||
@@ -1,3 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||||
|
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||||
|
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 初始化所有组件
|
||||||
|
2. 启动事件总线和任务, 所有任务都在这里运行
|
||||||
|
3. 执行启动完成事件钩子
|
||||||
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -18,102 +29,180 @@ from astrbot.core.updator import AstrBotUpdator
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.config.default import VERSION
|
from astrbot.core.config.default import VERSION
|
||||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||||
|
from astrbot.core.conversation_mgr import ConversationManager
|
||||||
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
|
from astrbot.core.star.star_handler import star_map
|
||||||
|
|
||||||
|
|
||||||
class AstrBotCoreLifecycle:
|
class AstrBotCoreLifecycle:
|
||||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
"""
|
||||||
self.log_broker = log_broker
|
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||||
self.astrbot_config = astrbot_config
|
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||||
self.db = db
|
EventBus 等。
|
||||||
|
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||||
|
"""
|
||||||
|
|
||||||
if self.astrbot_config['http_proxy']:
|
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||||
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
|
self.log_broker = log_broker # 初始化日志代理
|
||||||
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
|
self.astrbot_config = astrbot_config # 初始化配置
|
||||||
|
self.db = db # 初始化数据库
|
||||||
|
|
||||||
|
# 根据环境变量设置代理
|
||||||
|
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
|
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
|
os.environ["no_proxy"] = "localhost"
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
logger.info("AstrBot v"+ VERSION)
|
"""
|
||||||
if os.environ.get("TESTING", ""):
|
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||||
logger.setLevel("DEBUG")
|
"""
|
||||||
else:
|
|
||||||
logger.setLevel(self.astrbot_config['log_level'])
|
|
||||||
self.event_queue = Queue()
|
|
||||||
self.event_queue.closed = False
|
|
||||||
|
|
||||||
|
# 初始化日志代理
|
||||||
|
logger.info("AstrBot v" + VERSION)
|
||||||
|
if os.environ.get("TESTING", ""):
|
||||||
|
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
|
||||||
|
else:
|
||||||
|
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||||
|
|
||||||
|
# 初始化事件队列
|
||||||
|
self.event_queue = Queue()
|
||||||
|
|
||||||
|
# 初始化供应商管理器
|
||||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||||
|
|
||||||
|
# 初始化平台管理器
|
||||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||||
|
|
||||||
|
# 初始化知识库管理器
|
||||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||||
|
|
||||||
|
# 初始化对话管理器
|
||||||
|
self.conversation_manager = ConversationManager(self.db)
|
||||||
|
|
||||||
|
# 初始化提供给插件的上下文
|
||||||
self.star_context = Context(
|
self.star_context = Context(
|
||||||
self.event_queue,
|
self.event_queue,
|
||||||
self.astrbot_config,
|
self.astrbot_config,
|
||||||
self.db,
|
self.db,
|
||||||
self.provider_manager,
|
self.provider_manager,
|
||||||
self.platform_manager,
|
self.platform_manager,
|
||||||
self.knowledge_db_manager
|
self.conversation_manager,
|
||||||
|
self.knowledge_db_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 初始化插件管理器
|
||||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||||
|
|
||||||
|
# 扫描、注册插件、实例化插件类
|
||||||
await self.plugin_manager.reload()
|
await self.plugin_manager.reload()
|
||||||
'''扫描、注册插件、实例化插件类'''
|
|
||||||
|
|
||||||
|
# 根据配置实例化各个 Provider
|
||||||
await self.provider_manager.initialize()
|
await self.provider_manager.initialize()
|
||||||
'''根据配置实例化各个 Provider'''
|
|
||||||
|
|
||||||
await self.platform_manager.initialize()
|
# 初始化消息事件流水线调度器
|
||||||
'''根据配置实例化各个平台适配器'''
|
self.pipeline_scheduler = PipelineScheduler(
|
||||||
|
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||||
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
|
)
|
||||||
await self.pipeline_scheduler.initialize()
|
await self.pipeline_scheduler.initialize()
|
||||||
'''初始化消息事件流水线调度器'''
|
|
||||||
|
|
||||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror'])
|
# 初始化更新器
|
||||||
|
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
||||||
|
|
||||||
|
# 初始化事件总线
|
||||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||||
|
|
||||||
|
# 记录启动时间
|
||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
|
# 初始化当前任务列表
|
||||||
self.curr_tasks: List[asyncio.Task] = []
|
self.curr_tasks: List[asyncio.Task] = []
|
||||||
|
|
||||||
|
# 根据配置实例化各个平台适配器
|
||||||
|
await self.platform_manager.initialize()
|
||||||
|
|
||||||
|
# 初始化关闭控制面板的事件
|
||||||
|
self.dashboard_shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
def _load(self):
|
def _load(self):
|
||||||
|
"""加载事件总线和任务并初始化"""
|
||||||
|
|
||||||
platform_tasks = self.load_platform()
|
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||||
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
|
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
|
||||||
|
event_bus_task = asyncio.create_task(
|
||||||
|
self.event_bus.dispatch(), name="event_bus"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||||
extra_tasks = []
|
extra_tasks = []
|
||||||
for task in self.star_context._register_tasks:
|
for task in self.star_context._register_tasks:
|
||||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||||
|
|
||||||
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
tasks_ = [event_bus_task, *extra_tasks]
|
||||||
|
|
||||||
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
|
|
||||||
for task in tasks_:
|
for task in tasks_:
|
||||||
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
|
self.curr_tasks.append(
|
||||||
|
asyncio.create_task(self._task_wrapper(task), name=task.get_name())
|
||||||
|
)
|
||||||
|
|
||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
async def _task_wrapper(self, task: asyncio.Task):
|
async def _task_wrapper(self, task: asyncio.Task):
|
||||||
|
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (asyncio.Task): 要执行的异步任务
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass # 任务被取消, 静默处理
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
|
||||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||||
for line in traceback.format_exc().split("\n"):
|
for line in traceback.format_exc().split("\n"):
|
||||||
logger.error(f"| {line}")
|
logger.error(f"| {line}")
|
||||||
logger.error("-------")
|
logger.error("-------")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
|
||||||
self._load()
|
self._load()
|
||||||
logger.info("AstrBot 启动完成。")
|
logger.info("AstrBot 启动完成。")
|
||||||
|
|
||||||
|
# 执行启动完成事件钩子
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnAstrBotLoadedEvent
|
||||||
|
)
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler()
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# 同时运行curr_tasks中的所有任务
|
||||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.event_queue.closed = True
|
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
|
||||||
|
# 请求停止所有正在运行的异步任务
|
||||||
for task in self.curr_tasks:
|
for task in self.curr_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
await self.provider_manager.terminate()
|
for plugin in self.plugin_manager.context.get_all_stars():
|
||||||
|
try:
|
||||||
|
await self.plugin_manager._terminate_plugin(plugin)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(traceback.format_exc())
|
||||||
|
logger.warning(
|
||||||
|
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.provider_manager.terminate()
|
||||||
|
await self.platform_manager.terminate()
|
||||||
|
self.dashboard_shutdown_event.set()
|
||||||
|
|
||||||
|
# 再次遍历curr_tasks等待每个任务真正结束
|
||||||
for task in self.curr_tasks:
|
for task in self.curr_tasks:
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
@@ -122,13 +211,21 @@ class AstrBotCoreLifecycle:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||||
|
|
||||||
def restart(self):
|
async def restart(self):
|
||||||
self.event_queue.closed = True
|
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||||
threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start()
|
await self.provider_manager.terminate()
|
||||||
|
await self.platform_manager.terminate()
|
||||||
|
self.dashboard_shutdown_event.set()
|
||||||
|
threading.Thread(
|
||||||
|
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||||
|
).start()
|
||||||
|
|
||||||
def load_platform(self) -> List[asyncio.Task]:
|
def load_platform(self) -> List[asyncio.Task]:
|
||||||
|
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||||
tasks = []
|
tasks = []
|
||||||
platform_insts = self.platform_manager.get_insts()
|
platform_insts = self.platform_manager.get_insts()
|
||||||
for platform_inst in platform_insts:
|
for platform_inst in platform_insts:
|
||||||
tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name))
|
tasks.append(
|
||||||
|
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
|
||||||
|
)
|
||||||
return tasks
|
return tasks
|
||||||
@@ -1,103 +1,161 @@
|
|||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List, Dict, Any, Tuple
|
||||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
|
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseDatabase(abc.ABC):
|
class BaseDatabase(abc.ABC):
|
||||||
'''
|
"""
|
||||||
数据库基类
|
数据库基类
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def insert_base_metrics(self, metrics: dict):
|
def insert_base_metrics(self, metrics: dict):
|
||||||
'''插入基础指标数据'''
|
"""插入基础指标数据"""
|
||||||
self.insert_platform_metrics(metrics['platform_stats'])
|
self.insert_platform_metrics(metrics["platform_stats"])
|
||||||
self.insert_plugin_metrics(metrics['plugin_stats'])
|
self.insert_plugin_metrics(metrics["plugin_stats"])
|
||||||
self.insert_command_metrics(metrics['command_stats'])
|
self.insert_command_metrics(metrics["command_stats"])
|
||||||
self.insert_llm_metrics(metrics['llm_stats'])
|
self.insert_llm_metrics(metrics["llm_stats"])
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def insert_platform_metrics(self, metrics: dict):
|
def insert_platform_metrics(self, metrics: dict):
|
||||||
'''插入平台指标数据'''
|
"""插入平台指标数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def insert_plugin_metrics(self, metrics: dict):
|
def insert_plugin_metrics(self, metrics: dict):
|
||||||
'''插入插件指标数据'''
|
"""插入插件指标数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def insert_command_metrics(self, metrics: dict):
|
def insert_command_metrics(self, metrics: dict):
|
||||||
'''插入指令指标数据'''
|
"""插入指令指标数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def insert_llm_metrics(self, metrics: dict):
|
def insert_llm_metrics(self, metrics: dict):
|
||||||
'''插入 LLM 指标数据'''
|
"""插入 LLM 指标数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||||
'''更新 LLM 历史记录。当不存在 session_id 时插入'''
|
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]:
|
def get_llm_history(
|
||||||
'''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有'''
|
self, session_id: str = None, provider_type: str = None
|
||||||
|
) -> List[LLMHistory]:
|
||||||
|
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||||
'''获取基础统计数据'''
|
"""获取基础统计数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_total_message_count(self) -> int:
|
def get_total_message_count(self) -> int:
|
||||||
'''获取总消息数'''
|
"""获取总消息数"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||||
'''获取基础统计数据(合并)'''
|
"""获取基础统计数据(合并)"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def insert_atri_vision_data(self, vision_data: ATRIVision):
|
def insert_atri_vision_data(self, vision_data: ATRIVision):
|
||||||
'''插入 ATRI 视觉数据'''
|
"""插入 ATRI 视觉数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_atri_vision_data(self) -> List[ATRIVision]:
|
def get_atri_vision_data(self) -> List[ATRIVision]:
|
||||||
'''获取 ATRI 视觉数据'''
|
"""获取 ATRI 视觉数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
def get_atri_vision_data_by_path_or_id(
|
||||||
'''通过 url 或 path 获取 ATRI 视觉数据'''
|
self, url_or_path: str, id: str
|
||||||
|
) -> ATRIVision:
|
||||||
|
"""通过 url 或 path 获取 ATRI 视觉数据"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||||
'''通过 user_id 和 cid 获取 WebChatConversation'''
|
"""通过 user_id 和 cid 获取 Conversation"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
def new_conversation(self, user_id: str, cid: str):
|
||||||
'''新建 WebChatConversation'''
|
"""新建 Conversation"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
|
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||||
'''更新 WebChatConversation'''
|
"""更新 Conversation"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
def delete_conversation(self, user_id: str, cid: str):
|
||||||
'''删除 WebChatConversation'''
|
"""删除 Conversation"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||||
|
"""更新 Conversation 标题"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||||
|
"""更新 Conversation Persona ID"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_all_conversations(
|
||||||
|
self, page: int = 1, page_size: int = 20
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取所有对话,支持分页
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码,从1开始
|
||||||
|
page_size: 每页数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
platforms: List[str] = None,
|
||||||
|
message_types: List[str] = None,
|
||||||
|
search_query: str = None,
|
||||||
|
exclude_ids: List[str] = None,
|
||||||
|
exclude_platforms: List[str] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取筛选后的对话列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码
|
||||||
|
page_size: 每页数量
|
||||||
|
platforms: 平台筛选列表
|
||||||
|
message_types: 消息类型筛选列表
|
||||||
|
search_query: 搜索关键词
|
||||||
|
exclude_ids: 排除的用户ID列表
|
||||||
|
exclude_platforms: 排除的平台列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
112
astrbot/core/db/plugin/sqlite_impl.py
Normal file
112
astrbot/core/db/plugin/sqlite_impl.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import json
|
||||||
|
import aiosqlite
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from .plugin_storage import PluginStorage
|
||||||
|
|
||||||
|
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||||
|
|
||||||
|
|
||||||
|
class SQLitePluginStorage(PluginStorage):
|
||||||
|
"""插件数据的 SQLite 存储实现类。
|
||||||
|
|
||||||
|
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||||
|
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None # Standalone instance of the class
|
||||||
|
_db_conn = None
|
||||||
|
db_path = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""
|
||||||
|
创建或获取 SQLitePluginStorage 的单例实例。
|
||||||
|
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||||
|
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||||
|
"""
|
||||||
|
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||||
|
cls._instance.db_path = DBPATH
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
async def _init_db(self):
|
||||||
|
"""初始化数据库连接(只执行一次)"""
|
||||||
|
if SQLitePluginStorage._db_conn is None:
|
||||||
|
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||||
|
await self._setup_db()
|
||||||
|
|
||||||
|
async def _setup_db(self):
|
||||||
|
"""
|
||||||
|
异步初始化数据库。
|
||||||
|
|
||||||
|
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||||
|
其中 plugin 和 key 组合作为主键。
|
||||||
|
"""
|
||||||
|
await self._db_conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||||
|
plugin TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
PRIMARY KEY (plugin, key)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
await self._db_conn.commit()
|
||||||
|
|
||||||
|
async def set(self, plugin: str, key: str, value: Any):
|
||||||
|
"""
|
||||||
|
异步存储数据。
|
||||||
|
|
||||||
|
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||||
|
值会被序列化为 JSON 字符串后存储。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin: 插件标识符
|
||||||
|
key: 数据键名
|
||||||
|
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||||
|
"""
|
||||||
|
await self._init_db()
|
||||||
|
await self._db_conn.execute(
|
||||||
|
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||||
|
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||||
|
(plugin, key, json.dumps(value)),
|
||||||
|
)
|
||||||
|
await self._db_conn.commit()
|
||||||
|
|
||||||
|
async def get(self, plugin: str, key: str) -> Any:
|
||||||
|
"""
|
||||||
|
异步获取数据。
|
||||||
|
|
||||||
|
从数据库中获取指定插件和键名对应的值,
|
||||||
|
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin: 插件标识符
|
||||||
|
key: 数据键名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 存储的数据值,如果未找到则返回 None
|
||||||
|
"""
|
||||||
|
await self._init_db()
|
||||||
|
async with self._db_conn.execute(
|
||||||
|
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||||
|
(plugin, key),
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return json.loads(row[0]) if row else None
|
||||||
|
|
||||||
|
async def delete(self, plugin: str, key: str):
|
||||||
|
"""
|
||||||
|
异步删除数据。
|
||||||
|
|
||||||
|
从数据库中删除指定插件和键名对应的数据项。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin: 插件标识符
|
||||||
|
key: 要删除的数据键名
|
||||||
|
"""
|
||||||
|
await self._init_db()
|
||||||
|
await self._db_conn.execute(
|
||||||
|
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||||
|
)
|
||||||
|
await self._db_conn.commit()
|
||||||
@@ -1,48 +1,65 @@
|
|||||||
'''指标数据'''
|
"""指标数据"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Platform():
|
class Platform:
|
||||||
|
"""平台使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Provider():
|
class Provider:
|
||||||
|
"""供应商使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Plugin():
|
class Plugin:
|
||||||
|
"""插件使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Command():
|
class Command:
|
||||||
|
"""命令使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Stats():
|
class Stats:
|
||||||
platform: List[Platform] = field(default_factory=list)
|
platform: List[Platform] = field(default_factory=list)
|
||||||
command: List[Command] = field(default_factory=list)
|
command: List[Command] = field(default_factory=list)
|
||||||
llm: List[Provider] = field(default_factory=list)
|
llm: List[Provider] = field(default_factory=list)
|
||||||
|
|
||||||
'''LLM 聊天时持久化的信息'''
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMHistory():
|
class LLMHistory:
|
||||||
|
"""LLM 聊天时持久化的信息"""
|
||||||
|
|
||||||
provider_type: str
|
provider_type: str
|
||||||
session_id: str
|
session_id: str
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ATRIVision():
|
class ATRIVision:
|
||||||
|
"""Deprecated"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
url_or_path: str
|
url_or_path: str
|
||||||
caption: str
|
caption: str
|
||||||
@@ -54,12 +71,19 @@ class ATRIVision():
|
|||||||
timestamp: int = -1
|
timestamp: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WebChatConversation():
|
class Conversation:
|
||||||
|
"""LLM 对话存储
|
||||||
|
|
||||||
|
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||||
|
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||||
|
"""
|
||||||
|
|
||||||
user_id: str
|
user_id: str
|
||||||
cid: str
|
cid: str
|
||||||
history: str = ""
|
history: str = ""
|
||||||
|
"""字符串格式的列表。"""
|
||||||
created_at: int = 0
|
created_at: int = 0
|
||||||
updated_at: int = 0
|
updated_at: int = 0
|
||||||
|
title: str = ""
|
||||||
|
persona_id: str = ""
|
||||||
|
|||||||
@@ -1,15 +1,9 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
|
||||||
Platform,
|
|
||||||
Stats,
|
|
||||||
LLMHistory,
|
|
||||||
ATRIVision,
|
|
||||||
WebChatConversation
|
|
||||||
)
|
|
||||||
from . import BaseDatabase
|
from . import BaseDatabase
|
||||||
from typing import Tuple
|
from typing import Tuple, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase(BaseDatabase):
|
class SQLiteDatabase(BaseDatabase):
|
||||||
@@ -26,6 +20,37 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
c.executescript(sql)
|
c.executescript(sql)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
|
# 检查 webchat_conversation 的 title 字段是否存在
|
||||||
|
c.execute(
|
||||||
|
"""
|
||||||
|
PRAGMA table_info(webchat_conversation)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
res = c.fetchall()
|
||||||
|
has_title = False
|
||||||
|
has_persona_id = False
|
||||||
|
for row in res:
|
||||||
|
if row[1] == "title":
|
||||||
|
has_title = True
|
||||||
|
if row[1] == "persona_id":
|
||||||
|
has_persona_id = True
|
||||||
|
if not has_title:
|
||||||
|
c.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
if not has_persona_id:
|
||||||
|
c.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
c.close()
|
||||||
|
|
||||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
conn.text_factory = str
|
conn.text_factory = str
|
||||||
@@ -51,9 +76,10 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
def insert_platform_metrics(self, metrics: dict):
|
def insert_platform_metrics(self, metrics: dict):
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
||||||
''', (k, v, int(time.time()))
|
""",
|
||||||
|
(k, v, int(time.time())),
|
||||||
)
|
)
|
||||||
|
|
||||||
def insert_plugin_metrics(self, metrics: dict):
|
def insert_plugin_metrics(self, metrics: dict):
|
||||||
@@ -62,57 +88,63 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
def insert_command_metrics(self, metrics: dict):
|
def insert_command_metrics(self, metrics: dict):
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
|
INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
|
||||||
''', (k, v, int(time.time()))
|
""",
|
||||||
|
(k, v, int(time.time())),
|
||||||
)
|
)
|
||||||
|
|
||||||
def insert_llm_metrics(self, metrics: dict):
|
def insert_llm_metrics(self, metrics: dict):
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
||||||
''', (k, v, int(time.time()))
|
""",
|
||||||
|
(k, v, int(time.time())),
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||||
res = self.get_llm_history(session_id, provider_type)
|
res = self.get_llm_history(session_id, provider_type)
|
||||||
if res:
|
if res:
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
|
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
|
||||||
''', (content, session_id, provider_type)
|
""",
|
||||||
|
(content, session_id, provider_type),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
|
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
|
||||||
''', (provider_type, session_id, content)
|
""",
|
||||||
|
(provider_type, session_id, content),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple:
|
def get_llm_history(
|
||||||
|
self, session_id: str = None, provider_type: str = None
|
||||||
|
) -> Tuple:
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
where_clause = ""
|
conditions = []
|
||||||
if session_id or provider_type:
|
params = []
|
||||||
where_clause += " WHERE "
|
|
||||||
has = False
|
if session_id:
|
||||||
if session_id:
|
conditions.append("session_id = ?")
|
||||||
where_clause += f"session_id = '{session_id}'"
|
params.append(session_id)
|
||||||
has = True
|
|
||||||
if provider_type:
|
if provider_type:
|
||||||
if has:
|
conditions.append("provider_type = ?")
|
||||||
where_clause += " AND "
|
params.append(provider_type)
|
||||||
where_clause += f"provider_type = '{provider_type}'"
|
|
||||||
|
sql = "SELECT * FROM llm_history"
|
||||||
|
if conditions:
|
||||||
|
sql += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
|
c.execute(sql, params)
|
||||||
|
|
||||||
c.execute(
|
|
||||||
'''
|
|
||||||
SELECT * FROM llm_history
|
|
||||||
''' + where_clause
|
|
||||||
)
|
|
||||||
res = c.fetchall()
|
res = c.fetchall()
|
||||||
histories = []
|
histories = []
|
||||||
for row in res:
|
for row in res:
|
||||||
@@ -121,7 +153,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
return histories
|
return histories
|
||||||
|
|
||||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||||
'''获取 offset_sec 秒前到现在的基础统计数据'''
|
"""获取 offset_sec 秒前到现在的基础统计数据"""
|
||||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -130,9 +162,10 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT * FROM platform
|
SELECT * FROM platform
|
||||||
''' + where_clause
|
"""
|
||||||
|
+ where_clause
|
||||||
)
|
)
|
||||||
|
|
||||||
platform = []
|
platform = []
|
||||||
@@ -170,16 +203,16 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT SUM(count) FROM platform
|
SELECT SUM(count) FROM platform
|
||||||
'''
|
"""
|
||||||
)
|
)
|
||||||
res = c.fetchone()
|
res = c.fetchone()
|
||||||
c.close()
|
c.close()
|
||||||
return res[0]
|
return res[0]
|
||||||
|
|
||||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||||
'''获取 offset_sec 秒前到现在的基础统计数据(合并)'''
|
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
|
||||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -188,9 +221,11 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT name, SUM(count), timestamp FROM platform
|
SELECT name, SUM(count), timestamp FROM platform
|
||||||
''' + where_clause + " GROUP BY name"
|
"""
|
||||||
|
+ where_clause
|
||||||
|
+ " GROUP BY name"
|
||||||
)
|
)
|
||||||
|
|
||||||
platform = []
|
platform = []
|
||||||
@@ -201,43 +236,49 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
|
|
||||||
return Stats(platform, [], [])
|
return Stats(platform, [], [])
|
||||||
|
|
||||||
|
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||||
''', (user_id, cid)
|
""",
|
||||||
|
(user_id, cid),
|
||||||
)
|
)
|
||||||
|
|
||||||
res = c.fetchone()
|
res = c.fetchone()
|
||||||
c.close()
|
c.close()
|
||||||
return WebChatConversation(*res)
|
|
||||||
|
|
||||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
if not res:
|
||||||
|
return
|
||||||
|
|
||||||
|
return Conversation(*res)
|
||||||
|
|
||||||
|
def new_conversation(self, user_id: str, cid: str):
|
||||||
history = "[]"
|
history = "[]"
|
||||||
updated_at = int(time.time())
|
updated_at = int(time.time())
|
||||||
created_at = updated_at
|
created_at = updated_at
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||||
''', (user_id, cid, history, updated_at, created_at)
|
""",
|
||||||
|
(user_id, cid, history, updated_at, created_at),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_webchat_conversations(self, user_id: str) -> Tuple:
|
def get_conversations(self, user_id: str) -> Tuple:
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||||
''', (user_id,)
|
""",
|
||||||
|
(user_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
res = c.fetchall()
|
res = c.fetchall()
|
||||||
@@ -247,31 +288,65 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
cid = row[0]
|
cid = row[0]
|
||||||
created_at = row[1]
|
created_at = row[1]
|
||||||
updated_at = row[2]
|
updated_at = row[2]
|
||||||
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
|
title = row[3]
|
||||||
|
persona_id = row[4]
|
||||||
|
conversations.append(
|
||||||
|
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
||||||
|
)
|
||||||
return conversations
|
return conversations
|
||||||
|
|
||||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||||
|
"""更新对话,并且同时更新时间"""
|
||||||
|
updated_at = int(time.time())
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
|
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||||
''', (history, user_id, cid)
|
""",
|
||||||
|
(history, updated_at, user_id, cid),
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
|
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||||
|
""",
|
||||||
|
(title, user_id, cid),
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||||
|
self._exec_sql(
|
||||||
|
"""
|
||||||
|
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||||
|
""",
|
||||||
|
(persona_id, user_id, cid),
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_conversation(self, user_id: str, cid: str):
|
||||||
|
self._exec_sql(
|
||||||
|
"""
|
||||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||||
''', (user_id, cid)
|
""",
|
||||||
|
(user_id, cid),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||||
ts = int(time.time())
|
ts = int(time.time())
|
||||||
keywords = ",".join(vision.keywords)
|
keywords = ",".join(vision.keywords)
|
||||||
self._exec_sql(
|
self._exec_sql(
|
||||||
'''
|
"""
|
||||||
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts)
|
""",
|
||||||
|
(
|
||||||
|
vision.id,
|
||||||
|
vision.url_or_path,
|
||||||
|
vision.caption,
|
||||||
|
vision.is_meme,
|
||||||
|
keywords,
|
||||||
|
vision.platform_name,
|
||||||
|
vision.session_id,
|
||||||
|
vision.sender_nickname,
|
||||||
|
ts,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_atri_vision_data(self) -> Tuple:
|
def get_atri_vision_data(self) -> Tuple:
|
||||||
@@ -281,9 +356,9 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT * FROM atri_vision
|
SELECT * FROM atri_vision
|
||||||
'''
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
res = c.fetchall()
|
res = c.fetchall()
|
||||||
@@ -293,16 +368,19 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
c.close()
|
c.close()
|
||||||
return visions
|
return visions
|
||||||
|
|
||||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
def get_atri_vision_data_by_path_or_id(
|
||||||
|
self, url_or_path: str, id: str
|
||||||
|
) -> ATRIVision:
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
c.execute(
|
c.execute(
|
||||||
'''
|
"""
|
||||||
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
|
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
|
||||||
''', (url_or_path, id)
|
""",
|
||||||
|
(url_or_path, id),
|
||||||
)
|
)
|
||||||
|
|
||||||
res = c.fetchone()
|
res = c.fetchone()
|
||||||
@@ -310,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
if res:
|
if res:
|
||||||
return ATRIVision(*res)
|
return ATRIVision(*res)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_all_conversations(
|
||||||
|
self, page: int = 1, page_size: int = 20
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取所有对话,支持分页,按更新时间降序排序"""
|
||||||
|
try:
|
||||||
|
c = self.conn.cursor()
|
||||||
|
except sqlite3.ProgrammingError:
|
||||||
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取总记录数
|
||||||
|
c.execute("""
|
||||||
|
SELECT COUNT(*) FROM webchat_conversation
|
||||||
|
""")
|
||||||
|
total_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
# 计算偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 获取分页数据,按更新时间降序排序
|
||||||
|
c.execute(
|
||||||
|
"""
|
||||||
|
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||||
|
FROM webchat_conversation
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
""",
|
||||||
|
(page_size, offset),
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = c.fetchall()
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||||
|
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
|
||||||
|
safe_cid = str(cid) if cid else "unknown"
|
||||||
|
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||||
|
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"cid": safe_cid,
|
||||||
|
"title": title or f"对话 {display_cid}",
|
||||||
|
"persona_id": persona_id or "",
|
||||||
|
"created_at": created_at or 0,
|
||||||
|
"updated_at": updated_at or 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total_count
|
||||||
|
|
||||||
|
except Exception as _:
|
||||||
|
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||||
|
return [], 0
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
platforms: List[str] = None,
|
||||||
|
message_types: List[str] = None,
|
||||||
|
search_query: str = None,
|
||||||
|
exclude_ids: List[str] = None,
|
||||||
|
exclude_platforms: List[str] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取筛选后的对话列表"""
|
||||||
|
try:
|
||||||
|
c = self.conn.cursor()
|
||||||
|
except sqlite3.ProgrammingError:
|
||||||
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建查询条件
|
||||||
|
where_clauses = []
|
||||||
|
params = []
|
||||||
|
|
||||||
|
# 平台筛选
|
||||||
|
if platforms and len(platforms) > 0:
|
||||||
|
platform_conditions = []
|
||||||
|
for platform in platforms:
|
||||||
|
platform_conditions.append("user_id LIKE ?")
|
||||||
|
params.append(f"{platform}:%")
|
||||||
|
|
||||||
|
if platform_conditions:
|
||||||
|
where_clauses.append(f"({' OR '.join(platform_conditions)})")
|
||||||
|
|
||||||
|
# 消息类型筛选
|
||||||
|
if message_types and len(message_types) > 0:
|
||||||
|
message_type_conditions = []
|
||||||
|
for msg_type in message_types:
|
||||||
|
message_type_conditions.append("user_id LIKE ?")
|
||||||
|
params.append(f"%:{msg_type}:%")
|
||||||
|
|
||||||
|
if message_type_conditions:
|
||||||
|
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
||||||
|
|
||||||
|
# 搜索关键词
|
||||||
|
if search_query:
|
||||||
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||||
|
where_clauses.append(
|
||||||
|
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
||||||
|
)
|
||||||
|
search_param = f"%{search_query}%"
|
||||||
|
params.extend([search_param, search_param, search_param, search_param])
|
||||||
|
|
||||||
|
# 排除特定用户ID
|
||||||
|
if exclude_ids and len(exclude_ids) > 0:
|
||||||
|
for exclude_id in exclude_ids:
|
||||||
|
where_clauses.append("user_id NOT LIKE ?")
|
||||||
|
params.append(f"{exclude_id}%")
|
||||||
|
|
||||||
|
# 排除特定平台
|
||||||
|
if exclude_platforms and len(exclude_platforms) > 0:
|
||||||
|
for exclude_platform in exclude_platforms:
|
||||||
|
where_clauses.append("user_id NOT LIKE ?")
|
||||||
|
params.append(f"{exclude_platform}:%")
|
||||||
|
|
||||||
|
# 构建完整的 WHERE 子句
|
||||||
|
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||||
|
|
||||||
|
# 构建计数查询
|
||||||
|
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
|
||||||
|
|
||||||
|
# 获取总记录数
|
||||||
|
c.execute(count_sql, params)
|
||||||
|
total_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
# 计算偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 构建分页数据查询
|
||||||
|
data_sql = f"""
|
||||||
|
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||||
|
FROM webchat_conversation
|
||||||
|
{where_sql}
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
query_params = params + [page_size, offset]
|
||||||
|
|
||||||
|
# 获取分页数据
|
||||||
|
c.execute(data_sql, query_params)
|
||||||
|
rows = c.fetchall()
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||||
|
# 确保 cid 是字符串类型,否则使用一个默认值
|
||||||
|
safe_cid = str(cid) if cid else "unknown"
|
||||||
|
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||||
|
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"cid": safe_cid,
|
||||||
|
"title": title or f"对话 {display_cid}",
|
||||||
|
"persona_id": persona_id or "",
|
||||||
|
"created_at": created_at or 0,
|
||||||
|
"updated_at": updated_at or 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total_count
|
||||||
|
|
||||||
|
except Exception as _:
|
||||||
|
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||||
|
return [], 0
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
|||||||
@@ -38,9 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||||
user_id TEXT,
|
user_id TEXT, -- 会话 id
|
||||||
cid TEXT,
|
cid TEXT, -- 对话 id
|
||||||
history TEXT,
|
history TEXT,
|
||||||
created_at INTEGER,
|
created_at INTEGER,
|
||||||
updated_at INTEGER
|
updated_at INTEGER,
|
||||||
|
title TEXT,
|
||||||
|
persona_id TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
|
PRAGMA encoding = 'UTF-8';
|
||||||
@@ -1,23 +1,57 @@
|
|||||||
|
"""
|
||||||
|
事件总线, 用于处理事件的分发和处理
|
||||||
|
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||||
|
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
|
class:
|
||||||
|
EventBus: 事件总线, 用于处理事件的分发和处理
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 维护一个异步队列, 来接受各种消息事件
|
||||||
|
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from .platform import AstrMessageEvent
|
from .platform import AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
|
"""事件总线: 用于处理事件的分发和处理
|
||||||
|
|
||||||
|
维护一个异步队列, 来接受各种消息事件
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue # 事件队列
|
||||||
self.pipeline_scheduler = pipeline_scheduler
|
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||||
|
|
||||||
async def dispatch(self):
|
async def dispatch(self):
|
||||||
logger.info("事件总线已打开。")
|
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||||
while True:
|
while True:
|
||||||
event: AstrMessageEvent = await self.event_queue.get()
|
event: AstrMessageEvent = (
|
||||||
self._print_event(event)
|
await self.event_queue.get()
|
||||||
asyncio.create_task(self.pipeline_scheduler.execute(event))
|
) # 从事件队列中获取新的事件
|
||||||
|
self._print_event(event) # 打印日志
|
||||||
|
asyncio.create_task(
|
||||||
|
self.pipeline_scheduler.execute(event)
|
||||||
|
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
def _print_event(self, event: AstrMessageEvent):
|
def _print_event(self, event: AstrMessageEvent):
|
||||||
|
"""用于记录事件信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
"""
|
||||||
|
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||||
if event.get_sender_name():
|
if event.get_sender_name():
|
||||||
logger.info(f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}")
|
logger.info(
|
||||||
|
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||||
|
)
|
||||||
|
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||||
else:
|
else:
|
||||||
logger.info(f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}")
|
logger.info(
|
||||||
|
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||||
|
)
|
||||||
|
|||||||
48
astrbot/core/initial_loader.py
Normal file
48
astrbot/core/initial_loader.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""
|
||||||
|
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
|
||||||
|
2. 运行核心生命周期任务和仪表板服务器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core import LogBroker
|
||||||
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
|
|
||||||
|
|
||||||
|
class InitialLoader:
|
||||||
|
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
|
||||||
|
|
||||||
|
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
||||||
|
self.db = db
|
||||||
|
self.logger = logger
|
||||||
|
self.log_broker = log_broker
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||||
|
|
||||||
|
core_task = []
|
||||||
|
try:
|
||||||
|
await core_lifecycle.initialize()
|
||||||
|
core_task = core_lifecycle.start()
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(traceback.format_exc())
|
||||||
|
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||||
|
|
||||||
|
self.dashboard_server = AstrBotDashboard(
|
||||||
|
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||||
|
)
|
||||||
|
task = asyncio.gather(
|
||||||
|
core_task, self.dashboard_server.run()
|
||||||
|
) # 启动核心任务和仪表板服务器
|
||||||
|
|
||||||
|
try:
|
||||||
|
await task # 整个AstrBot在这里运行
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("🌈 正在关闭 AstrBot...")
|
||||||
|
await core_lifecycle.stop()
|
||||||
@@ -1,25 +1,99 @@
|
|||||||
|
"""
|
||||||
|
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
|
||||||
|
|
||||||
|
const:
|
||||||
|
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
|
||||||
|
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
|
||||||
|
|
||||||
|
class:
|
||||||
|
LogBroker: 日志代理类, 用于缓存和分发日志消息
|
||||||
|
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
LogManager: 日志管理器, 用于创建和配置日志记录器
|
||||||
|
|
||||||
|
function:
|
||||||
|
is_plugin_path: 检查文件路径是否来自插件目录
|
||||||
|
get_short_level_name: 将日志级别名称转换为四个字母的缩写
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
|
||||||
|
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
|
||||||
|
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
|
||||||
|
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import colorlog
|
import colorlog
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
# 日志缓存大小
|
||||||
CACHED_SIZE = 200
|
CACHED_SIZE = 200
|
||||||
|
# 日志颜色配置
|
||||||
log_color_config = {
|
log_color_config = {
|
||||||
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
|
"DEBUG": "green",
|
||||||
'WARNING': 'bold_yellow', 'ERROR': 'red',
|
"INFO": "bold_cyan",
|
||||||
'CRITICAL': 'bold_red', 'RESET': 'reset',
|
"WARNING": "bold_yellow",
|
||||||
'asctime': 'green'
|
"ERROR": "red",
|
||||||
|
"CRITICAL": "bold_red",
|
||||||
|
"RESET": "reset",
|
||||||
|
"asctime": "green",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_plugin_path(pathname):
|
||||||
|
"""检查文件路径是否来自插件目录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pathname (str): 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果路径来自插件目录,则返回 True,否则返回 False
|
||||||
|
"""
|
||||||
|
if not pathname:
|
||||||
|
return False
|
||||||
|
|
||||||
|
norm_path = os.path.normpath(pathname)
|
||||||
|
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_short_level_name(level_name):
|
||||||
|
"""将日志级别名称转换为四个字母的缩写
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 四个字母的日志级别缩写
|
||||||
|
"""
|
||||||
|
level_map = {
|
||||||
|
"DEBUG": "DBUG",
|
||||||
|
"INFO": "INFO",
|
||||||
|
"WARNING": "WARN",
|
||||||
|
"ERROR": "ERRO",
|
||||||
|
"CRITICAL": "CRIT",
|
||||||
|
}
|
||||||
|
return level_map.get(level_name, level_name[:4].upper())
|
||||||
|
|
||||||
|
|
||||||
class LogBroker:
|
class LogBroker:
|
||||||
|
"""日志代理类, 用于缓存和分发日志消息
|
||||||
|
|
||||||
|
发布-订阅模式
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_cache = deque(maxlen=CACHED_SIZE)
|
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
|
||||||
self.subscribers: List[Queue] = []
|
self.subscribers: List[Queue] = [] # 订阅者列表
|
||||||
|
|
||||||
def register(self) -> Queue:
|
def register(self) -> Queue:
|
||||||
'''给每个订阅者返回一个带有日志缓存的队列'''
|
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Queue: 订阅者的队列, 可用于接收日志消息
|
||||||
|
"""
|
||||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||||
for log in self.log_cache:
|
for log in self.log_cache:
|
||||||
q.put_nowait(log)
|
q.put_nowait(log)
|
||||||
@@ -27,11 +101,20 @@ class LogBroker:
|
|||||||
return q
|
return q
|
||||||
|
|
||||||
def unregister(self, q: Queue):
|
def unregister(self, q: Queue):
|
||||||
'''取消订阅'''
|
"""取消订阅
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (Queue): 需要取消订阅的队列
|
||||||
|
"""
|
||||||
self.subscribers.remove(q)
|
self.subscribers.remove(q)
|
||||||
|
|
||||||
def publish(self, log_entry: str):
|
def publish(self, log_entry: dict):
|
||||||
'''发布消息'''
|
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_entry (dict): 日志消息, 包含日志级别和日志内容.
|
||||||
|
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
|
||||||
|
"""
|
||||||
self.log_cache.append(log_entry)
|
self.log_cache.append(log_entry)
|
||||||
for q in self.subscribers:
|
for q in self.subscribers:
|
||||||
try:
|
try:
|
||||||
@@ -39,41 +122,124 @@ class LogBroker:
|
|||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LogQueueHandler(logging.Handler):
|
class LogQueueHandler(logging.Handler):
|
||||||
|
"""日志处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
|
||||||
|
继承自 logging.Handler
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, log_broker: LogBroker):
|
def __init__(self, log_broker: LogBroker):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
|
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
|
||||||
|
这个方法会在每次日志记录时被调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record (logging.LogRecord): 日志记录对象, 包含日志信息
|
||||||
|
"""
|
||||||
log_entry = self.format(record)
|
log_entry = self.format(record)
|
||||||
self.log_broker.publish(log_entry)
|
self.log_broker.publish(
|
||||||
|
{
|
||||||
|
"level": record.levelname,
|
||||||
|
"time": record.asctime,
|
||||||
|
"data": log_entry,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogManager:
|
class LogManager:
|
||||||
|
"""日志管理器, 用于创建和配置日志记录器
|
||||||
|
|
||||||
|
提供了获取默认日志记录器logger和设置队列处理器的方法
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def GetLogger(cls, log_name: str = 'default'):
|
def GetLogger(cls, log_name: str = "default"):
|
||||||
|
"""获取指定名称的日志记录器logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_name (str): 日志记录器的名称, 默认为 "default"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 返回配置好的日志记录器
|
||||||
|
"""
|
||||||
logger = logging.getLogger(log_name)
|
logger = logging.getLogger(log_name)
|
||||||
|
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
|
||||||
if logger.hasHandlers():
|
if logger.hasHandlers():
|
||||||
return logger
|
return logger
|
||||||
console_handler = logging.StreamHandler()
|
# 如果logger没有处理器
|
||||||
console_handler.setLevel(logging.DEBUG)
|
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
|
||||||
|
console_handler.setLevel(
|
||||||
|
logging.DEBUG
|
||||||
|
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||||
|
|
||||||
|
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
|
||||||
console_formatter = colorlog.ColoredFormatter(
|
console_formatter = colorlog.ColoredFormatter(
|
||||||
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
|
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
|
||||||
datefmt='%H:%M:%S',
|
datefmt="%H:%M:%S",
|
||||||
log_colors=log_color_config
|
log_colors=log_color_config,
|
||||||
)
|
)
|
||||||
console_handler.setFormatter(console_formatter)
|
|
||||||
logger.setLevel(logging.DEBUG)
|
class PluginFilter(logging.Filter):
|
||||||
logger.addHandler(console_handler)
|
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
record.plugin_tag = (
|
||||||
|
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
class FileNameFilter(logging.Filter):
|
||||||
|
"""文件名过滤器类, 用于修改日志记录的文件名格式
|
||||||
|
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
|
||||||
|
|
||||||
|
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
|
||||||
|
def filter(self, record):
|
||||||
|
dirname = os.path.dirname(record.pathname)
|
||||||
|
record.filename = (
|
||||||
|
os.path.basename(dirname)
|
||||||
|
+ "."
|
||||||
|
+ os.path.basename(record.pathname).replace(".py", "")
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
class LevelNameFilter(logging.Filter):
|
||||||
|
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
|
||||||
|
|
||||||
|
# 添加短日志级别名称
|
||||||
|
def filter(self, record):
|
||||||
|
record.short_levelname = get_short_level_name(record.levelname)
|
||||||
|
return True
|
||||||
|
|
||||||
|
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
|
||||||
|
logger.addFilter(PluginFilter()) # 添加插件过滤器
|
||||||
|
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
|
||||||
|
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
|
||||||
|
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
|
||||||
|
logger.addHandler(console_handler) # 添加处理器到logger
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
||||||
|
"""设置队列处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger (logging.Logger): 日志记录器
|
||||||
|
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
|
||||||
|
"""
|
||||||
handler = LogQueueHandler(log_broker)
|
handler = LogQueueHandler(log_broker)
|
||||||
handler.setLevel(logging.DEBUG)
|
handler.setLevel(logging.DEBUG)
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
handler.setFormatter(logger.handlers[0].formatter)
|
handler.setFormatter(logger.handlers[0].formatter)
|
||||||
else:
|
else:
|
||||||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
# 为队列处理器设置相同格式的formatter
|
||||||
|
handler.setFormatter(
|
||||||
|
logging.Formatter(
|
||||||
|
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
'''
|
"""
|
||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2021 Lxns-Network
|
Copyright (c) 2021 Lxns-Network
|
||||||
@@ -20,21 +20,32 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
SOFTWARE.
|
SOFTWARE.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
import typing as T
|
import typing as T
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic.v1 import BaseModel
|
from pydantic.v1 import BaseModel
|
||||||
|
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||||
|
|
||||||
|
|
||||||
class ComponentType(Enum):
|
class ComponentType(Enum):
|
||||||
Plain = "Plain"
|
Plain = "Plain" # 纯文本消息
|
||||||
Face = "Face"
|
Face = "Face" # QQ表情
|
||||||
Record = "Record"
|
Record = "Record" # 语音
|
||||||
Video = "Video"
|
Video = "Video" # 视频
|
||||||
At = "At"
|
At = "At" # At
|
||||||
|
Node = "Node" # 转发消息的一个节点
|
||||||
|
Nodes = "Nodes" # 转发消息的多个节点
|
||||||
|
Poke = "Poke" # QQ 戳一戳
|
||||||
|
Image = "Image" # 图片
|
||||||
|
Reply = "Reply" # 回复
|
||||||
|
Forward = "Forward" # 转发消息
|
||||||
|
File = "File" # 文件
|
||||||
|
|
||||||
RPS = "RPS" # TODO
|
RPS = "RPS" # TODO
|
||||||
Dice = "Dice" # TODO
|
Dice = "Dice" # TODO
|
||||||
Shake = "Shake" # TODO
|
Shake = "Shake" # TODO
|
||||||
@@ -43,18 +54,14 @@ class ComponentType(Enum):
|
|||||||
Contact = "Contact" # TODO
|
Contact = "Contact" # TODO
|
||||||
Location = "Location" # TODO
|
Location = "Location" # TODO
|
||||||
Music = "Music"
|
Music = "Music"
|
||||||
Image = "Image"
|
|
||||||
Reply = "Reply"
|
|
||||||
RedBag = "RedBag"
|
RedBag = "RedBag"
|
||||||
Poke = "Poke"
|
|
||||||
Forward = "Forward"
|
|
||||||
Node = "Node"
|
|
||||||
Xml = "Xml"
|
Xml = "Xml"
|
||||||
Json = "Json"
|
Json = "Json"
|
||||||
CardImage = "CardImage"
|
CardImage = "CardImage"
|
||||||
TTS = "TTS"
|
TTS = "TTS"
|
||||||
Unknown = "Unknown"
|
Unknown = "Unknown"
|
||||||
File = "File"
|
|
||||||
|
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageComponent(BaseModel):
|
class BaseMessageComponent(BaseModel):
|
||||||
@@ -69,25 +76,26 @@ class BaseMessageComponent(BaseModel):
|
|||||||
k = "type"
|
k = "type"
|
||||||
if isinstance(v, bool):
|
if isinstance(v, bool):
|
||||||
v = 1 if v else 0
|
v = 1 if v else 0
|
||||||
output += ",%s=%s" % (k, str(v).replace("&", "&") \
|
output += ",%s=%s" % (
|
||||||
.replace(",", ",") \
|
k,
|
||||||
.replace("[", "[") \
|
str(v)
|
||||||
.replace("]", "]"))
|
.replace("&", "&")
|
||||||
|
.replace(",", ",")
|
||||||
|
.replace("[", "[")
|
||||||
|
.replace("]", "]"),
|
||||||
|
)
|
||||||
output += "]"
|
output += "]"
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def toDict(self):
|
def toDict(self):
|
||||||
data = dict()
|
data = {}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k == "type" or v is None:
|
if k == "type" or v is None:
|
||||||
continue
|
continue
|
||||||
if k == "_type":
|
if k == "_type":
|
||||||
k = "type"
|
k = "type"
|
||||||
data[k] = v
|
data[k] = v
|
||||||
return {
|
return {"type": self.type.lower(), "data": data}
|
||||||
"type": self.type.lower(),
|
|
||||||
"data": data
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Plain(BaseMessageComponent):
|
class Plain(BaseMessageComponent):
|
||||||
@@ -101,9 +109,9 @@ class Plain(BaseMessageComponent):
|
|||||||
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
|
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
|
||||||
if not self.convert:
|
if not self.convert:
|
||||||
return self.text
|
return self.text
|
||||||
return self.text.replace("&", "&") \
|
return (
|
||||||
.replace("[", "[") \
|
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||||
.replace("]", "]")
|
)
|
||||||
|
|
||||||
|
|
||||||
class Face(BaseMessageComponent):
|
class Face(BaseMessageComponent):
|
||||||
@@ -142,6 +150,51 @@ class Record(BaseMessageComponent):
|
|||||||
return Record(file=url, **_)
|
return Record(file=url, **_)
|
||||||
raise Exception("not a valid url")
|
raise Exception("not a valid url")
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 语音的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
if self.file and self.file.startswith("file:///"):
|
||||||
|
file_path = self.file[8:]
|
||||||
|
return file_path
|
||||||
|
elif self.file and self.file.startswith("http"):
|
||||||
|
file_path = await download_image_by_url(self.file)
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
elif self.file and self.file.startswith("base64://"):
|
||||||
|
bs64_data = self.file.removeprefix("base64://")
|
||||||
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
|
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
elif os.path.exists(self.file):
|
||||||
|
file_path = self.file
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
|
||||||
|
async def convert_to_base64(self) -> str:
|
||||||
|
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
|
"""
|
||||||
|
# convert to base64
|
||||||
|
if self.file and self.file.startswith("file:///"):
|
||||||
|
bs64_data = file_to_base64(self.file[8:])
|
||||||
|
elif self.file and self.file.startswith("http"):
|
||||||
|
file_path = await download_image_by_url(self.file)
|
||||||
|
bs64_data = file_to_base64(file_path)
|
||||||
|
elif self.file and self.file.startswith("base64://"):
|
||||||
|
bs64_data = self.file
|
||||||
|
elif os.path.exists(self.file):
|
||||||
|
bs64_data = file_to_base64(self.file)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
return bs64_data
|
||||||
|
|
||||||
|
|
||||||
class Video(BaseMessageComponent):
|
class Video(BaseMessageComponent):
|
||||||
type: ComponentType = "Video"
|
type: ComponentType = "Video"
|
||||||
@@ -275,10 +328,6 @@ class Image(BaseMessageComponent):
|
|||||||
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
||||||
|
|
||||||
def __init__(self, file: T.Optional[str], **_):
|
def __init__(self, file: T.Optional[str], **_):
|
||||||
# for k in _.keys():
|
|
||||||
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
|
|
||||||
# (k == "c" and _[k] not in [2, 3]):
|
|
||||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
|
||||||
super().__init__(file=file, **_)
|
super().__init__(file=file, **_)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -303,14 +352,77 @@ class Image(BaseMessageComponent):
|
|||||||
def fromIO(IO):
|
def fromIO(IO):
|
||||||
return Image.fromBytes(IO.read())
|
return Image.fromBytes(IO.read())
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 图片的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
url = self.url if self.url else self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
image_file_path = url[8:]
|
||||||
|
return image_file_path
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(url)
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
elif url and url.startswith("base64://"):
|
||||||
|
bs64_data = url.removeprefix("base64://")
|
||||||
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
|
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||||
|
with open(image_file_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
elif os.path.exists(url):
|
||||||
|
image_file_path = url
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
|
||||||
|
async def convert_to_base64(self) -> str:
|
||||||
|
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
|
"""
|
||||||
|
# convert to base64
|
||||||
|
url = self.url if self.url else self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
bs64_data = file_to_base64(url[8:])
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(url)
|
||||||
|
bs64_data = file_to_base64(image_file_path)
|
||||||
|
elif url and url.startswith("base64://"):
|
||||||
|
bs64_data = url
|
||||||
|
elif os.path.exists(url):
|
||||||
|
bs64_data = file_to_base64(url)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
return bs64_data
|
||||||
|
|
||||||
|
|
||||||
class Reply(BaseMessageComponent):
|
class Reply(BaseMessageComponent):
|
||||||
type: ComponentType = "Reply"
|
type: ComponentType = "Reply"
|
||||||
id: int
|
id: T.Union[str, int]
|
||||||
text: T.Optional[str] = ""
|
"""所引用的消息 ID"""
|
||||||
qq: T.Optional[int] = 0
|
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||||
|
"""引用的消息段列表"""
|
||||||
|
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||||
|
"""引用的消息发送者 ID"""
|
||||||
|
sender_nickname: T.Optional[str] = ""
|
||||||
|
"""引用的消息发送者昵称"""
|
||||||
time: T.Optional[int] = 0
|
time: T.Optional[int] = 0
|
||||||
|
"""引用的消息发送时间"""
|
||||||
|
message_str: T.Optional[str] = ""
|
||||||
|
"""解析后的纯文本消息字符串"""
|
||||||
|
sender_str: T.Optional[str] = ""
|
||||||
|
"""被引用的消息纯文本"""
|
||||||
|
|
||||||
|
text: T.Optional[str] = ""
|
||||||
|
"""deprecated"""
|
||||||
|
qq: T.Optional[int] = 0
|
||||||
|
"""deprecated"""
|
||||||
seq: T.Optional[int] = 0
|
seq: T.Optional[int] = 0
|
||||||
|
"""deprecated"""
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
@@ -325,11 +437,13 @@ class RedBag(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Poke(BaseMessageComponent):
|
class Poke(BaseMessageComponent):
|
||||||
type: ComponentType = "Poke"
|
type: str = ""
|
||||||
qq: int
|
id: T.Optional[int] = 0
|
||||||
|
qq: T.Optional[int] = 0
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, type: str, **_):
|
||||||
super().__init__(**_)
|
type = f"Poke:{type}"
|
||||||
|
super().__init__(type=type, **_)
|
||||||
|
|
||||||
|
|
||||||
class Forward(BaseMessageComponent):
|
class Forward(BaseMessageComponent):
|
||||||
@@ -340,21 +454,29 @@ class Forward(BaseMessageComponent):
|
|||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送
|
class Node(BaseMessageComponent):
|
||||||
|
"""群合并转发消息"""
|
||||||
|
|
||||||
type: ComponentType = "Node"
|
type: ComponentType = "Node"
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0 # 忽略
|
||||||
name: T.Optional[str] = ""
|
name: T.Optional[str] = "" # qq昵称
|
||||||
uin: T.Optional[int] = 0
|
uin: T.Optional[int] = 0 # qq号
|
||||||
content: T.Optional[T.Union[str, list]] = ""
|
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||||
seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么
|
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||||
time: T.Optional[int] = 0
|
time: T.Optional[int] = 0
|
||||||
|
|
||||||
def __init__(self, content: T.Union[str, list], **_):
|
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
|
_content = None
|
||||||
|
if all(isinstance(item, Node) for item in content):
|
||||||
|
_content = [node.toDict() for node in content]
|
||||||
|
else:
|
||||||
_content = ""
|
_content = ""
|
||||||
for chain in content:
|
for chain in content:
|
||||||
_content += chain.toString()
|
_content += chain.toString()
|
||||||
content = _content
|
content = _content
|
||||||
|
elif isinstance(content, Node):
|
||||||
|
content = content.toDict()
|
||||||
super().__init__(content=content, **_)
|
super().__init__(content=content, **_)
|
||||||
|
|
||||||
def toString(self):
|
def toString(self):
|
||||||
@@ -362,6 +484,17 @@ class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForw
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class Nodes(BaseMessageComponent):
|
||||||
|
type: ComponentType = "Nodes"
|
||||||
|
nodes: T.List[Node]
|
||||||
|
|
||||||
|
def __init__(self, nodes: T.List[Node], **_):
|
||||||
|
super().__init__(nodes=nodes, **_)
|
||||||
|
|
||||||
|
def toDict(self):
|
||||||
|
return {"messages": [node.toDict() for node in self.nodes]}
|
||||||
|
|
||||||
|
|
||||||
class Xml(BaseMessageComponent):
|
class Xml(BaseMessageComponent):
|
||||||
type: ComponentType = "Xml"
|
type: ComponentType = "Xml"
|
||||||
data: str
|
data: str
|
||||||
@@ -416,10 +549,12 @@ class Unknown(BaseMessageComponent):
|
|||||||
def toString(self):
|
def toString(self):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class File(BaseMessageComponent):
|
class File(BaseMessageComponent):
|
||||||
'''
|
"""
|
||||||
目前此消息段只适配了 Napcat。
|
目前此消息段只适配了 Napcat。
|
||||||
'''
|
"""
|
||||||
|
|
||||||
type: ComponentType = "File"
|
type: ComponentType = "File"
|
||||||
name: T.Optional[str] = "" # 名字
|
name: T.Optional[str] = "" # 名字
|
||||||
file: T.Optional[str] = "" # url(本地路径)
|
file: T.Optional[str] = "" # url(本地路径)
|
||||||
@@ -428,6 +563,16 @@ class File(BaseMessageComponent):
|
|||||||
super().__init__(name=name, file=file)
|
super().__init__(name=name, file=file)
|
||||||
|
|
||||||
|
|
||||||
|
class WechatEmoji(BaseMessageComponent):
|
||||||
|
type: ComponentType = "WechatEmoji"
|
||||||
|
md5: T.Optional[str] = ""
|
||||||
|
md5_len: T.Optional[int] = 0
|
||||||
|
cdnurl: T.Optional[str] = ""
|
||||||
|
|
||||||
|
def __init__(self, **_):
|
||||||
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
ComponentTypes = {
|
ComponentTypes = {
|
||||||
"plain": Plain,
|
"plain": Plain,
|
||||||
"text": Plain,
|
"text": Plain,
|
||||||
@@ -449,10 +594,12 @@ ComponentTypes = {
|
|||||||
"poke": Poke,
|
"poke": Poke,
|
||||||
"forward": Forward,
|
"forward": Forward,
|
||||||
"node": Node,
|
"node": Node,
|
||||||
|
"nodes": Nodes,
|
||||||
"xml": Xml,
|
"xml": Xml,
|
||||||
"json": Json,
|
"json": Json,
|
||||||
"cardimage": CardImage,
|
"cardimage": CardImage,
|
||||||
"tts": TTS,
|
"tts": TTS,
|
||||||
"unknown": Unknown,
|
"unknown": Unknown,
|
||||||
'file': File,
|
"file": File,
|
||||||
|
"WechatEmoji": WechatEmoji,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,51 +1,80 @@
|
|||||||
import enum
|
import enum
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union, AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
|
from astrbot.core.message.components import (
|
||||||
|
BaseMessageComponent,
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
At,
|
||||||
|
AtAll,
|
||||||
|
)
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageChain():
|
class MessageChain:
|
||||||
'''MessageChain 描述了一整条消息中带有的所有组件。
|
"""MessageChain 描述了一整条消息中带有的所有组件。
|
||||||
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
|
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
`chain` (list): 用于顺序存储各个组件。
|
`chain` (list): 用于顺序存储各个组件。
|
||||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
"""
|
||||||
'''
|
|
||||||
|
|
||||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||||
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
|
||||||
|
|
||||||
def message(self, message: str):
|
def message(self, message: str):
|
||||||
'''添加一条文本消息到消息链 `chain` 中。
|
"""添加一条文本消息到消息链 `chain` 中。
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
CommandResult().message("Hello ").message("world!")
|
CommandResult().message("Hello ").message("world!")
|
||||||
# 输出 Hello world!
|
# 输出 Hello world!
|
||||||
|
|
||||||
'''
|
"""
|
||||||
self.chain.append(Plain(message))
|
self.chain.append(Plain(message))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def at(self, name: str, qq: Union[str, int]):
|
||||||
|
"""添加一条 At 消息到消息链 `chain` 中。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().at("张三", "12345678910")
|
||||||
|
# 输出 @张三
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chain.append(At(name=name, qq=qq))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def at_all(self):
|
||||||
|
"""添加一条 AtAll 消息到消息链 `chain` 中。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().at_all()
|
||||||
|
# 输出 @所有人
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chain.append(AtAll())
|
||||||
|
return self
|
||||||
|
|
||||||
@deprecated("请使用 message 方法代替。")
|
@deprecated("请使用 message 方法代替。")
|
||||||
def error(self, message: str):
|
def error(self, message: str):
|
||||||
'''添加一条错误消息到消息链 `chain` 中
|
"""添加一条错误消息到消息链 `chain` 中
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
CommandResult().error("解析失败")
|
CommandResult().error("解析失败")
|
||||||
|
|
||||||
'''
|
"""
|
||||||
self.chain.append(Plain(message))
|
self.chain.append(Plain(message))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def url_image(self, url: str):
|
def url_image(self, url: str):
|
||||||
'''添加一条图片消息(https 链接)到消息链 `chain` 中。
|
"""添加一条图片消息(https 链接)到消息链 `chain` 中。
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
如果需要发送本地图片,请使用 `file_image` 方法。
|
如果需要发送本地图片,请使用 `file_image` 方法。
|
||||||
@@ -54,104 +83,140 @@ class MessageChain():
|
|||||||
|
|
||||||
CommandResult().image("https://example.com/image.jpg")
|
CommandResult().image("https://example.com/image.jpg")
|
||||||
|
|
||||||
'''
|
"""
|
||||||
self.chain.append(Image.fromURL(url))
|
self.chain.append(Image.fromURL(url))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def file_image(self, path: str):
|
def file_image(self, path: str):
|
||||||
'''添加一条图片消息(本地文件路径)到消息链 `chain` 中。
|
"""添加一条图片消息(本地文件路径)到消息链 `chain` 中。
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
如果需要发送网络图片,请使用 `url_image` 方法。
|
如果需要发送网络图片,请使用 `url_image` 方法。
|
||||||
|
|
||||||
CommandResult().image("image.jpg")
|
CommandResult().image("image.jpg")
|
||||||
'''
|
"""
|
||||||
self.chain.append(Image.fromFileSystem(path))
|
self.chain.append(Image.fromFileSystem(path))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def use_t2i(self, use_t2i: bool):
|
def use_t2i(self, use_t2i: bool):
|
||||||
'''设置是否使用文本转图片服务。
|
"""设置是否使用文本转图片服务。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||||
'''
|
"""
|
||||||
self.use_t2i_ = use_t2i
|
self.use_t2i_ = use_t2i
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def is_split(self, is_split: bool):
|
def get_plain_text(self) -> str:
|
||||||
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||||
|
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||||
|
|
||||||
Note:
|
def squash_plain(self):
|
||||||
具体的效果以各适配器实现为准。
|
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||||
|
if not self.chain:
|
||||||
|
return
|
||||||
|
|
||||||
'''
|
new_chain = []
|
||||||
self.is_split_ = is_split
|
first_plain = None
|
||||||
|
plain_texts = []
|
||||||
|
|
||||||
|
for comp in self.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
if first_plain is None:
|
||||||
|
first_plain = comp
|
||||||
|
new_chain.append(comp)
|
||||||
|
plain_texts.append(comp.text)
|
||||||
|
else:
|
||||||
|
new_chain.append(comp)
|
||||||
|
|
||||||
|
if first_plain is not None:
|
||||||
|
first_plain.text = "".join(plain_texts)
|
||||||
|
|
||||||
|
self.chain = new_chain
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class EventResultType(enum.Enum):
|
class EventResultType(enum.Enum):
|
||||||
'''用于描述事件处理的结果类型。
|
"""用于描述事件处理的结果类型。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
CONTINUE: 事件将会继续传播
|
CONTINUE: 事件将会继续传播
|
||||||
STOP: 事件将会终止传播
|
STOP: 事件将会终止传播
|
||||||
'''
|
"""
|
||||||
|
|
||||||
CONTINUE = enum.auto()
|
CONTINUE = enum.auto()
|
||||||
STOP = enum.auto()
|
STOP = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
class ResultContentType(enum.Enum):
|
class ResultContentType(enum.Enum):
|
||||||
'''用于描述事件结果的内容的类型。
|
"""用于描述事件结果的内容的类型。"""
|
||||||
'''
|
|
||||||
LLM_RESULT = enum.auto()
|
LLM_RESULT = enum.auto()
|
||||||
'''调用 LLM 产生的结果'''
|
"""调用 LLM 产生的结果"""
|
||||||
GENERAL_RESULT = enum.auto()
|
GENERAL_RESULT = enum.auto()
|
||||||
'''普通的消息结果'''
|
"""普通的消息结果"""
|
||||||
|
STREAMING_RESULT = enum.auto()
|
||||||
|
"""调用 LLM 产生的流式结果"""
|
||||||
|
STREAMING_FINISH= enum.auto()
|
||||||
|
"""流式输出完成"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageEventResult(MessageChain):
|
class MessageEventResult(MessageChain):
|
||||||
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
|
"""MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
|
||||||
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
|
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
`chain` (list): 用于顺序存储各个组件。
|
`chain` (list): 用于顺序存储各个组件。
|
||||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
|
||||||
`result_type` (EventResultType): 事件处理的结果类型。
|
`result_type` (EventResultType): 事件处理的结果类型。
|
||||||
'''
|
"""
|
||||||
|
|
||||||
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
|
result_type: Optional[EventResultType] = field(
|
||||||
|
default_factory=lambda: EventResultType.CONTINUE
|
||||||
|
)
|
||||||
|
|
||||||
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
|
result_content_type: Optional[ResultContentType] = field(
|
||||||
|
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
||||||
|
)
|
||||||
|
|
||||||
def stop_event(self) -> 'MessageEventResult':
|
async_stream: Optional[AsyncGenerator] = None
|
||||||
'''终止事件传播。
|
"""异步流"""
|
||||||
'''
|
|
||||||
|
def stop_event(self) -> "MessageEventResult":
|
||||||
|
"""终止事件传播。"""
|
||||||
self.result_type = EventResultType.STOP
|
self.result_type = EventResultType.STOP
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def continue_event(self) -> 'MessageEventResult':
|
def continue_event(self) -> "MessageEventResult":
|
||||||
'''继续事件传播。
|
"""继续事件传播。"""
|
||||||
'''
|
|
||||||
self.result_type = EventResultType.CONTINUE
|
self.result_type = EventResultType.CONTINUE
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def is_stopped(self) -> bool:
|
def is_stopped(self) -> bool:
|
||||||
'''
|
"""
|
||||||
是否终止事件传播。
|
是否终止事件传播。
|
||||||
'''
|
"""
|
||||||
return self.result_type == EventResultType.STOP
|
return self.result_type == EventResultType.STOP
|
||||||
|
|
||||||
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
|
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
|
||||||
'''设置事件处理的结果类型。
|
"""设置异步流。"""
|
||||||
|
self.async_stream = stream
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
||||||
|
"""设置事件处理的结果类型。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result_type (EventResultType): 事件处理的结果类型。
|
result_type (EventResultType): 事件处理的结果类型。
|
||||||
'''
|
"""
|
||||||
self.result_content_type = typ
|
self.result_content_type = typ
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def is_llm_result(self) -> bool:
|
def is_llm_result(self) -> bool:
|
||||||
'''是否为 LLM 结果。
|
"""是否为 LLM 结果。"""
|
||||||
'''
|
|
||||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||||
|
|
||||||
|
|
||||||
|
# 为了兼容旧版代码,保留 CommandResult 的别名
|
||||||
CommandResult = MessageEventResult
|
CommandResult = MessageEventResult
|
||||||
@@ -1,32 +1,41 @@
|
|||||||
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
|
from astrbot.core.message.message_event_result import (
|
||||||
|
MessageEventResult,
|
||||||
|
EventResultType,
|
||||||
|
)
|
||||||
|
|
||||||
from .waking_check.stage import WakingCheckStage
|
from .waking_check.stage import WakingCheckStage
|
||||||
from .whitelist_check.stage import WhitelistCheckStage
|
from .whitelist_check.stage import WhitelistCheckStage
|
||||||
|
from .rate_limit_check.stage import RateLimitStage
|
||||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||||
|
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||||
from .preprocess_stage.stage import PreProcessStage
|
from .preprocess_stage.stage import PreProcessStage
|
||||||
from .process_stage.stage import ProcessStage
|
from .process_stage.stage import ProcessStage
|
||||||
from .result_decorate.stage import ResultDecorateStage
|
from .result_decorate.stage import ResultDecorateStage
|
||||||
from .respond.stage import RespondStage
|
from .respond.stage import RespondStage
|
||||||
|
|
||||||
|
# 管道阶段顺序
|
||||||
STAGES_ORDER = [
|
STAGES_ORDER = [
|
||||||
"WakingCheckStage", # 检查是否需要唤醒
|
"WakingCheckStage", # 检查是否需要唤醒
|
||||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||||
"RateLimitCheckStage", # 检查会话是否超过频率限制
|
"RateLimitStage", # 检查会话是否超过频率限制
|
||||||
"ContentSafetyCheckStage", # 检查内容安全
|
"ContentSafetyCheckStage", # 检查内容安全
|
||||||
|
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||||
"PreProcessStage", # 预处理
|
"PreProcessStage", # 预处理
|
||||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||||
"RespondStage" # 发送消息
|
"RespondStage", # 发送消息
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"WakingCheckStage",
|
"WakingCheckStage",
|
||||||
"WhitelistCheckStage",
|
"WhitelistCheckStage",
|
||||||
|
"RateLimitStage",
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
|
"PlatformCompatibilityStage",
|
||||||
"PreProcessStage",
|
"PreProcessStage",
|
||||||
"ProcessStage",
|
"ProcessStage",
|
||||||
"ResultDecorateStage",
|
"ResultDecorateStage",
|
||||||
"RespondStage",
|
"RespondStage",
|
||||||
"MessageEventResult",
|
"MessageEventResult",
|
||||||
"EventResultType"
|
"EventResultType",
|
||||||
]
|
]
|
||||||
@@ -6,23 +6,32 @@ from astrbot.core.message.message_event_result import MessageEventResult
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from .strategies.strategy import StrategySelector
|
from .strategies.strategy import StrategySelector
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class ContentSafetyCheckStage(Stage):
|
class ContentSafetyCheckStage(Stage):
|
||||||
'''检查内容安全
|
"""检查内容安全
|
||||||
|
|
||||||
当前只会检查文本的。
|
当前只会检查文本的。
|
||||||
'''
|
"""
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext):
|
async def initialize(self, ctx: PipelineContext):
|
||||||
config = ctx.astrbot_config['content_safety']
|
config = ctx.astrbot_config["content_safety"]
|
||||||
self.strategy_selector = StrategySelector(config)
|
self.strategy_selector = StrategySelector(config)
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
async def process(
|
||||||
'''检查内容安全'''
|
self, event: AstrMessageEvent, check_text: str = None
|
||||||
ok, info = self.strategy_selector.check(event.get_message_str())
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
"""检查内容安全"""
|
||||||
|
text = check_text if check_text else event.get_message_str()
|
||||||
|
ok, info = self.strategy_selector.check(text)
|
||||||
if not ok:
|
if not ok:
|
||||||
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
|
if event.is_at_or_wake_command:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
"你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
logger.info(f"内容安全检查不通过,原因:{info}")
|
logger.info(f"内容安全检查不通过,原因:{info}")
|
||||||
return
|
return
|
||||||
event.continue_event()
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import abc
|
import abc
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
class ContentSafetyStrategy(abc.ABC):
|
|
||||||
|
|
||||||
|
class ContentSafetyStrategy(abc.ABC):
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def check(self, content: str) -> Tuple[bool, str]:
|
def check(self, content: str) -> Tuple[bool, str]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -1,30 +1,30 @@
|
|||||||
'''
|
"""
|
||||||
使用此功能应该先 pip install baidu-aip
|
使用此功能应该先 pip install baidu-aip
|
||||||
'''
|
"""
|
||||||
|
|
||||||
from . import ContentSafetyStrategy
|
from . import ContentSafetyStrategy
|
||||||
from aip import AipContentCensor
|
from aip import AipContentCensor
|
||||||
|
|
||||||
|
|
||||||
class BaiduAipStrategy(ContentSafetyStrategy):
|
class BaiduAipStrategy(ContentSafetyStrategy):
|
||||||
def __init__(self, appid: str, ak: str, sk: str) -> None:
|
def __init__(self, appid: str, ak: str, sk: str) -> None:
|
||||||
self.app_id = appid
|
self.app_id = appid
|
||||||
self.api_key = ak
|
self.api_key = ak
|
||||||
self.secret_key = sk
|
self.secret_key = sk
|
||||||
self.client = AipContentCensor(self.app_id,
|
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
||||||
self.api_key,
|
|
||||||
self.secret_key)
|
|
||||||
|
|
||||||
def check(self, content: str):
|
def check(self, content: str):
|
||||||
res = self.client.textCensorUserDefined(content)
|
res = self.client.textCensorUserDefined(content)
|
||||||
if 'conclusionType' not in res:
|
if "conclusionType" not in res:
|
||||||
return False, ""
|
return False, ""
|
||||||
if res['conclusionType'] == 1:
|
if res["conclusionType"] == 1:
|
||||||
return True, ""
|
return True, ""
|
||||||
else:
|
else:
|
||||||
if 'data' not in res:
|
if "data" not in res:
|
||||||
return False, ""
|
return False, ""
|
||||||
count = len(res['data'])
|
count = len(res["data"])
|
||||||
info = f"百度审核服务发现 {count} 处违规:\n"
|
info = f"百度审核服务发现 {count} 处违规:\n"
|
||||||
for i in res['data']:
|
for i in res["data"]:
|
||||||
info += f"{i['msg']};\n"
|
info += f"{i['msg']};\n"
|
||||||
info += "\n判断结果:"+res['conclusion']
|
info += "\n判断结果:" + res["conclusion"]
|
||||||
return False, info
|
return False, info
|
||||||
@@ -1,20 +1,20 @@
|
|||||||
import re
|
import re
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import base64
|
|
||||||
from . import ContentSafetyStrategy
|
from . import ContentSafetyStrategy
|
||||||
|
|
||||||
|
|
||||||
class KeywordsStrategy(ContentSafetyStrategy):
|
class KeywordsStrategy(ContentSafetyStrategy):
|
||||||
def __init__(self, extra_keywords: list) -> None:
|
def __init__(self, extra_keywords: list) -> None:
|
||||||
self.keywords = []
|
self.keywords = []
|
||||||
if extra_keywords is None:
|
if extra_keywords is None:
|
||||||
extra_keywords = []
|
extra_keywords = []
|
||||||
self.keywords.extend(extra_keywords)
|
self.keywords.extend(extra_keywords)
|
||||||
keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words')
|
# keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
|
||||||
# internal keywords
|
# internal keywords
|
||||||
if os.path.exists(keywords_path):
|
# if os.path.exists(keywords_path):
|
||||||
with open(keywords_path, "r", encoding="utf-8") as f:
|
# with open(keywords_path, "r", encoding="utf-8") as f:
|
||||||
self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'])
|
# self.keywords.extend(
|
||||||
|
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||||
|
# )
|
||||||
|
|
||||||
def check(self, content: str) -> bool:
|
def check(self, content: str) -> bool:
|
||||||
for keyword in self.keywords:
|
for keyword in self.keywords:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from . import ContentSafetyStrategy
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
class StrategySelector:
|
class StrategySelector:
|
||||||
def __init__(self, config: dict) -> None:
|
def __init__(self, config: dict) -> None:
|
||||||
self.enabled_strategies: List[ContentSafetyStrategy] = []
|
self.enabled_strategies: List[ContentSafetyStrategy] = []
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
ewogICAgImtleXdvcmRzIjogWwogICAgICAgICLkuaDov5HlubMiLAogICAgICAgICLog6HplKbmtpsiLAogICAgICAgICLmsZ/ms73msJEiLAogICAgICAgICLmuKnlrrblrp0iLAogICAgICAgICLmnY7lhYvlvLoiLAogICAgICAgICLmnY7plb/mmKUiLAogICAgICAgICLmr5vms73kuJwiLAogICAgICAgICLpgpPlsI/lubMiLAogICAgICAgICLlkajmganmnaUiLAogICAgICAgICLnpL7kvJrkuLvkuYkiLAogICAgICAgICLlhbHkuqflhZoiLAogICAgICAgICLlhbHkuqfkuLvkuYkiLAogICAgICAgICLlpKfpmYblrpjmlrkiLAogICAgICAgICLljJfkuqzmlL/mnYMiLAogICAgICAgICLkuK3ljY7luJ3lm70iLAogICAgICAgICLkuK3lm73mlL/lupwiLAogICAgICAgICLlhbHni5ciLAogICAgICAgICLlha3lm5vkuovku7YiLAogICAgICAgICLlpKnlronpl6giLAogICAgICAgICLlha3lm5siLAogICAgICAgICLmlL/msrvlsYDluLjlp5QiLAogICAgICAgICLlrabmva4iLAogICAgICAgICLlhavkuZ0iLAogICAgICAgICLkuozljYHlpKciLAogICAgICAgICLmsJHov5vlhZoiLAogICAgICAgICLlj7Dni6wiLAogICAgICAgICLlj7Dmub7ni6znq4siLAogICAgICAgICLlj7Dmub7lm70iLAogICAgICAgICLlm73msJHlhZoiLAogICAgICAgICLlj7Dmub7msJHlm70iLAogICAgICAgICLkuK3ljY7msJHlm70iLAogICAgICAgICJwb3JuaHViIiwKICAgICAgICAiUG9ybmh1YiIsCiAgICAgICAgIuS9nOeIsSIsCiAgICAgICAgIuWBmueIsSIsCiAgICAgICAgIuaAp+S6pCIsCiAgICAgICAgIuiHquaFsCIsCiAgICAgICAgIumYtOiMjiIsCiAgICAgICAgIua3q+WmhyIsCiAgICAgICAgIuiCm+S6pCIsCiAgICAgICAgIuS6pOmFjSIsCiAgICAgICAgIuaAp+WFs+ezuyIsCiAgICAgICAgIuaAp+a0u+WKqCIsCiAgICAgICAgIuiJsuaDhSIsCiAgICAgICAgIuiJsuWbviIsCiAgICAgICAgIuijuOS9kyIsCiAgICAgICAgIuWwj+eptCIsCiAgICAgICAgIua3q+iNoSIsCiAgICAgICAgIuaAp+eIsSIsCiAgICAgICAgIua4r+eLrCIsCiAgICAgICAgIuazlei9ruWKnyIsCiAgICAgICAgIuWFreWbmyIKICAgIF0KfQ==
|
|
||||||
@@ -2,7 +2,10 @@ from dataclasses import dataclass
|
|||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from astrbot.core.star import PluginManager
|
from astrbot.core.star import PluginManager
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineContext:
|
class PipelineContext:
|
||||||
astrbot_config: AstrBotConfig
|
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||||
plugin_manager: PluginManager
|
|
||||||
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||||
|
plugin_manager: PluginManager # 插件管理器对象
|
||||||
|
|||||||
56
astrbot/core/pipeline/platform_compatibility/stage.py
Normal file
56
astrbot/core/pipeline/platform_compatibility/stage.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from ..stage import Stage, register_stage
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from typing import Union, AsyncGenerator
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
@register_stage
|
||||||
|
class PlatformCompatibilityStage(Stage):
|
||||||
|
"""检查所有处理器的平台兼容性。
|
||||||
|
|
||||||
|
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
"""初始化平台兼容性检查阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
|
self.ctx = ctx
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
# 获取当前平台ID
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
|
||||||
|
# 获取已激活的处理器
|
||||||
|
activated_handlers = event.get_extra("activated_handlers")
|
||||||
|
if activated_handlers is None:
|
||||||
|
activated_handlers = []
|
||||||
|
|
||||||
|
# 标记不兼容的处理器
|
||||||
|
for handler in activated_handlers:
|
||||||
|
if not isinstance(handler, StarHandlerMetadata):
|
||||||
|
continue
|
||||||
|
# 检查处理器是否在当前平台启用
|
||||||
|
enabled = handler.is_enabled_for_platform(platform_id)
|
||||||
|
if not enabled:
|
||||||
|
if handler.handler_module_path in star_map:
|
||||||
|
plugin_name = star_map[handler.handler_module_path].name
|
||||||
|
logger.debug(
|
||||||
|
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||||
|
)
|
||||||
|
# 设置处理器为平台不兼容状态
|
||||||
|
# TODO: 更好的标记方式
|
||||||
|
handler.platform_compatible = False
|
||||||
|
else:
|
||||||
|
# 确保处理器为平台兼容状态
|
||||||
|
handler.platform_compatible = True
|
||||||
|
|
||||||
|
# 更新已激活的处理器列表
|
||||||
|
event.set_extra("activated_handlers", activated_handlers)
|
||||||
@@ -7,22 +7,23 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import Plain, Record, Image
|
from astrbot.core.message.components import Plain, Record, Image
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class PreProcessStage(Stage):
|
class PreProcessStage(Stage):
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.config = ctx.astrbot_config
|
self.config = ctx.astrbot_config
|
||||||
self.plugin_manager = ctx.plugin_manager
|
self.plugin_manager = ctx.plugin_manager
|
||||||
|
|
||||||
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
|
self.stt_settings: dict = self.config.get("provider_stt_settings", {})
|
||||||
self.platform_settings: dict = self.config.get('platform_settings', {})
|
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||||
|
|
||||||
|
async def process(
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
self, event: AstrMessageEvent
|
||||||
'''在处理事件之前的预处理'''
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
"""在处理事件之前的预处理"""
|
||||||
# 路径映射
|
# 路径映射
|
||||||
if mappings := self.platform_settings.get('path_mapping', []):
|
if mappings := self.platform_settings.get("path_mapping", []):
|
||||||
# 支持 Record,Image 消息段的路径映射。
|
# 支持 Record,Image 消息段的路径映射。
|
||||||
message_chain = event.get_messages()
|
message_chain = event.get_messages()
|
||||||
|
|
||||||
@@ -40,9 +41,11 @@ class PreProcessStage(Stage):
|
|||||||
message_chain[idx] = component
|
message_chain[idx] = component
|
||||||
|
|
||||||
# STT
|
# STT
|
||||||
if self.stt_settings.get('enable', False):
|
if self.stt_settings.get("enable", False):
|
||||||
# TODO: 独立
|
# TODO: 独立
|
||||||
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
stt_provider = (
|
||||||
|
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||||
|
)
|
||||||
if stt_provider:
|
if stt_provider:
|
||||||
message_chain = event.get_messages()
|
message_chain = event.get_messages()
|
||||||
for idx, component in enumerate(message_chain):
|
for idx, component in enumerate(message_chain):
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
'''
|
|
||||||
Dify 调用 Stage
|
|
||||||
'''
|
|
||||||
import traceback
|
|
||||||
from typing import Union, AsyncGenerator
|
|
||||||
from ...context import PipelineContext
|
|
||||||
from ..stage import Stage
|
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
|
|
||||||
from astrbot.core.message.components import Image
|
|
||||||
from astrbot.core import logger
|
|
||||||
from astrbot.core.utils.metrics import Metric
|
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
|
||||||
|
|
||||||
class DifyRequestSubStage(Stage):
|
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
|
||||||
self.ctx = ctx
|
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
|
||||||
req: ProviderRequest = None
|
|
||||||
|
|
||||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
|
||||||
if provider.meta().type != "dify":
|
|
||||||
return
|
|
||||||
|
|
||||||
if event.get_extra("provider_request"):
|
|
||||||
req = event.get_extra("provider_request")
|
|
||||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
|
||||||
else:
|
|
||||||
req = ProviderRequest(prompt="", image_urls=[])
|
|
||||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
|
||||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
|
||||||
return
|
|
||||||
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
|
||||||
for comp in event.message_obj.message:
|
|
||||||
if isinstance(comp, Image):
|
|
||||||
image_url = comp.url if comp.url else comp.file
|
|
||||||
req.image_urls.append(image_url)
|
|
||||||
req.session_id = event.session_id
|
|
||||||
event.set_extra("provider_request", req)
|
|
||||||
|
|
||||||
if not req.prompt:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
|
|
||||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
|
||||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
|
||||||
|
|
||||||
if llm_response.role == 'assistant':
|
|
||||||
# text completion
|
|
||||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
|
||||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
|
||||||
yield # rick roll
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
|
|
||||||
return
|
|
||||||
@@ -1,31 +1,63 @@
|
|||||||
'''
|
"""
|
||||||
本地 Agent 模式的 LLM 调用 Stage
|
本地 Agent 模式的 LLM 调用 Stage
|
||||||
'''
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ...context import PipelineContext
|
from ...context import PipelineContext
|
||||||
from ..stage import Stage
|
from ..stage import Stage
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
|
from astrbot.core.message.message_event_result import (
|
||||||
|
MessageEventResult,
|
||||||
|
ResultContentType,
|
||||||
|
MessageChain,
|
||||||
|
)
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from astrbot.core.provider.entities import (
|
||||||
|
ProviderRequest,
|
||||||
|
LLMResponse,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
AssistantMessageSegment,
|
||||||
|
ToolCallsResult,
|
||||||
|
)
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestSubStage(Stage):
|
class LLMRequestSubStage(Stage):
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
|
self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list
|
||||||
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
|
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
||||||
|
"wake_prefix"
|
||||||
|
] # str
|
||||||
|
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||||
|
"max_context_length"
|
||||||
|
] # int
|
||||||
|
self.dequeue_context_length = min(
|
||||||
|
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||||
|
self.max_context_length - 1,
|
||||||
|
) # int
|
||||||
|
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||||
|
"streaming_response"
|
||||||
|
] # bool
|
||||||
|
|
||||||
for bwp in self.bot_wake_prefixs:
|
for bwp in self.bot_wake_prefixs:
|
||||||
if self.provider_wake_prefix.startswith(bwp):
|
if self.provider_wake_prefix.startswith(bwp):
|
||||||
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
|
logger.info(
|
||||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
|
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
|
||||||
|
)
|
||||||
|
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent, _nested: bool = False
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
req: ProviderRequest = None
|
req: ProviderRequest = None
|
||||||
|
|
||||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||||
@@ -34,76 +66,470 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
if event.get_extra("provider_request"):
|
if event.get_extra("provider_request"):
|
||||||
req = event.get_extra("provider_request")
|
req = event.get_extra("provider_request")
|
||||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
assert isinstance(
|
||||||
|
req, ProviderRequest
|
||||||
|
), "provider_request 必须是 ProviderRequest 类型。"
|
||||||
|
|
||||||
|
if req.conversation:
|
||||||
|
all_contexts = json.loads(req.conversation.history)
|
||||||
|
req.contexts = self._process_tool_message_pairs(
|
||||||
|
all_contexts, remove_tags=True
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
req = ProviderRequest(prompt="", image_urls=[])
|
req = ProviderRequest(prompt="", image_urls=[])
|
||||||
if self.provider_wake_prefix:
|
if self.provider_wake_prefix:
|
||||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||||
return
|
return
|
||||||
req.prompt = event.message_str[len(self.provider_wake_prefix):]
|
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||||
for comp in event.message_obj.message:
|
for comp in event.message_obj.message:
|
||||||
if isinstance(comp, Image):
|
if isinstance(comp, Image):
|
||||||
image_url = comp.url if comp.url else comp.file
|
image_path = await comp.convert_to_file_path()
|
||||||
req.image_urls.append(image_url)
|
req.image_urls.append(image_path)
|
||||||
req.session_id = event.session_id
|
|
||||||
|
# 获取对话上下文
|
||||||
|
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
||||||
|
event.unified_msg_origin
|
||||||
|
)
|
||||||
|
if not conversation_id:
|
||||||
|
conversation_id = await self.conv_manager.new_conversation(
|
||||||
|
event.unified_msg_origin
|
||||||
|
)
|
||||||
|
conversation = await self.conv_manager.get_conversation(
|
||||||
|
event.unified_msg_origin, conversation_id
|
||||||
|
)
|
||||||
|
if not conversation:
|
||||||
|
conversation_id = await self.conv_manager.new_conversation(
|
||||||
|
event.unified_msg_origin
|
||||||
|
)
|
||||||
|
conversation = await self.conv_manager.get_conversation(
|
||||||
|
event.unified_msg_origin, conversation_id
|
||||||
|
)
|
||||||
|
req.conversation = conversation
|
||||||
|
req.contexts = json.loads(conversation.history)
|
||||||
|
|
||||||
event.set_extra("provider_request", req)
|
event.set_extra("provider_request", req)
|
||||||
session_provider_context = provider.session_memory.get(event.session_id)
|
|
||||||
req.contexts = session_provider_context if session_provider_context else []
|
|
||||||
|
|
||||||
if not req.prompt and not req.image_urls:
|
if not req.prompt and not req.image_urls:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 执行请求 LLM 前事件。
|
# 执行请求 LLM 前事件钩子。
|
||||||
# 装饰 system_prompt 等功能
|
# 装饰 system_prompt 等功能
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
# 获取当前平台ID
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||||
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
await handler.handler(event, req)
|
await handler.handler(event, req)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
try:
|
if event.is_stopped():
|
||||||
logger.debug(f"提供商请求 Payload: {req.__dict__}")
|
logger.info(
|
||||||
if _nested:
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
)
|
||||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
return
|
||||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
|
||||||
|
|
||||||
if llm_response.role == 'assistant':
|
if isinstance(req.contexts, str):
|
||||||
# text completion
|
req.contexts = json.loads(req.contexts)
|
||||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
|
||||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
# max context length
|
||||||
elif llm_response.role == 'tool':
|
if (
|
||||||
# function calling
|
self.max_context_length != -1 # -1 为不限制
|
||||||
function_calling_result = {}
|
and len(req.contexts) // 2 > self.max_context_length
|
||||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
):
|
||||||
func_tool = req.func_tool.get_func(func_tool_name)
|
logger.debug("上下文长度超过限制,将截断。")
|
||||||
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
req.contexts = req.contexts[
|
||||||
|
-(self.max_context_length - self.dequeue_context_length) * 2 :
|
||||||
|
]
|
||||||
|
|
||||||
|
# session_id
|
||||||
|
if not req.session_id:
|
||||||
|
req.session_id = event.unified_msg_origin
|
||||||
|
|
||||||
|
async def requesting(req: ProviderRequest):
|
||||||
try:
|
try:
|
||||||
# 尝试调用工具函数
|
need_loop = True
|
||||||
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
|
while need_loop:
|
||||||
async for resp in wrapper:
|
need_loop = False
|
||||||
if resp is not None:
|
logger.debug(f"提供商请求 Payload: {req}")
|
||||||
function_calling_result[func_tool_name] = resp
|
|
||||||
|
final_llm_response = None
|
||||||
|
|
||||||
|
if self.streaming_response:
|
||||||
|
stream = provider.text_chat_stream(**req.__dict__)
|
||||||
|
async for llm_response in stream:
|
||||||
|
if llm_response.is_chunk:
|
||||||
|
if llm_response.result_chain:
|
||||||
|
yield llm_response.result_chain # MessageChain
|
||||||
|
else:
|
||||||
|
yield MessageChain().message(
|
||||||
|
llm_response.completion_text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_llm_response = llm_response
|
||||||
|
else:
|
||||||
|
final_llm_response = await provider.text_chat(
|
||||||
|
**req.__dict__
|
||||||
|
) # 请求 LLM
|
||||||
|
|
||||||
|
if not final_llm_response:
|
||||||
|
raise Exception("LLM response is None.")
|
||||||
|
|
||||||
|
# 执行 LLM 响应后的事件钩子。
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnLLMResponseEvent
|
||||||
|
)
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler(event, final_llm_response)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if event.is_stopped():
|
||||||
|
logger.info(
|
||||||
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.streaming_response:
|
||||||
|
# 流式输出的处理
|
||||||
|
async for result in self._handle_llm_stream_response(
|
||||||
|
event, req, final_llm_response
|
||||||
|
):
|
||||||
|
if isinstance(result, ProviderRequest):
|
||||||
|
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||||
|
req = result
|
||||||
|
need_loop = True
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
event.clear_result() # 清除上一个 handler 的结果
|
else:
|
||||||
except BaseException as e:
|
# 非流式输出的处理
|
||||||
logger.warning(traceback.format_exc())
|
async for result in self._handle_llm_response(
|
||||||
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
|
event, req, final_llm_response
|
||||||
if function_calling_result:
|
):
|
||||||
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
|
if isinstance(result, ProviderRequest):
|
||||||
# 我们重新执行一遍这个 stage
|
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
req = result
|
||||||
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
|
need_loop = True
|
||||||
for tool_name, tool_result in function_calling_result.items():
|
else:
|
||||||
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
|
||||||
req.prompt += extra_prompt
|
|
||||||
async for _ in self.process(event, _nested=True):
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
llm_tick=1,
|
||||||
|
model_name=provider.get_model(),
|
||||||
|
provider_type=provider.meta().type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到历史记录
|
||||||
|
await self._save_to_history(event, req, final_llm_response)
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.streaming_response:
|
||||||
|
event.set_extra("tool_call_result", None)
|
||||||
|
async for _ in requesting(req):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||||
|
.set_async_stream(requesting(req))
|
||||||
|
)
|
||||||
|
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||||
|
yield
|
||||||
|
|
||||||
|
if event.get_extra("tool_call_result"):
|
||||||
|
event.set_result(event.get_extra("tool_call_result"))
|
||||||
|
event.set_extra("tool_call_result", None)
|
||||||
|
yield
|
||||||
|
|
||||||
|
async def _handle_llm_response(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理非流式 LLM 响应。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# text completion
|
||||||
|
if llm_response.result_chain:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult(
|
||||||
|
chain=llm_response.result_chain.chain
|
||||||
|
).set_result_content_type(ResultContentType.LLM_RESULT)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.message(llm_response.completion_text)
|
||||||
|
.set_result_content_type(ResultContentType.LLM_RESULT)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "err":
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "tool":
|
||||||
|
# 处理函数工具调用
|
||||||
|
async for result in self._handle_function_tools(event, req, llm_response):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
async def _handle_llm_stream_response(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理流式 LLM 响应。
|
||||||
|
|
||||||
|
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# text completion
|
||||||
|
if llm_response.result_chain:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult(
|
||||||
|
chain=llm_response.result_chain.chain
|
||||||
|
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.message(llm_response.completion_text)
|
||||||
|
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "err":
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "tool":
|
||||||
|
# 处理函数工具调用
|
||||||
|
async for result in self._handle_function_tools(event, req, llm_response):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
async def _handle_function_tools(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理函数工具调用。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
# function calling
|
||||||
|
tool_call_result: list[ToolCallMessageSegment] = []
|
||||||
|
logger.info(
|
||||||
|
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||||
|
)
|
||||||
|
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||||
|
llm_response.tools_call_name,
|
||||||
|
llm_response.tools_call_args,
|
||||||
|
llm_response.tools_call_ids,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
func_tool = req.func_tool.get_func(func_tool_name)
|
||||||
|
if func_tool.origin == "mcp":
|
||||||
|
logger.info(
|
||||||
|
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||||
|
)
|
||||||
|
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||||
|
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||||
|
if res:
|
||||||
|
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=res.content[0].text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 获取处理器,过滤掉平台不兼容的处理器
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
star_md = star_map.get(func_tool.handler_module_path)
|
||||||
|
if (
|
||||||
|
star_md and
|
||||||
|
platform_id in star_md.supported_platforms
|
||||||
|
and not star_md.supported_platforms[platform_id]
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||||
|
)
|
||||||
|
# 直接跳过,不添加任何消息到tool_call_result
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||||
|
)
|
||||||
|
# 尝试调用工具函数
|
||||||
|
wrapper = self._call_handler(
|
||||||
|
self.ctx, event, func_tool.handler, **func_tool_args
|
||||||
|
)
|
||||||
|
async for resp in wrapper:
|
||||||
|
if resp is not None: # 有 return 返回
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=resp,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
res = event.get_result()
|
||||||
|
if res and res.chain:
|
||||||
|
event.set_extra("tool_call_result", res)
|
||||||
|
yield # 有生成器返回
|
||||||
|
event.clear_result() # 清除上一个 handler 的结果
|
||||||
|
except BaseException as e:
|
||||||
|
logger.warning(traceback.format_exc())
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=f"error: {str(e)}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if tool_call_result:
|
||||||
|
# 函数调用结果
|
||||||
|
req.func_tool = None # 暂时不支持递归工具调用
|
||||||
|
assistant_msg_seg = AssistantMessageSegment(
|
||||||
|
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||||
|
)
|
||||||
|
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||||
|
req.tool_calls_result = ToolCallsResult(
|
||||||
|
tool_calls_info=assistant_msg_seg,
|
||||||
|
tool_calls_result=tool_call_result,
|
||||||
|
)
|
||||||
|
yield req # 再次执行 LLM 请求
|
||||||
|
else:
|
||||||
|
if llm_response.completion_text:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(llm_response.completion_text)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _save_to_history(
|
||||||
|
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||||
|
):
|
||||||
|
if not req or not req.conversation or not llm_response:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# 文本回复
|
||||||
|
contexts = req.contexts.copy()
|
||||||
|
contexts.append(await req.assemble_context())
|
||||||
|
|
||||||
|
# 记录并标记函数调用结果
|
||||||
|
if req.tool_calls_result:
|
||||||
|
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||||
|
|
||||||
|
# 添加标记
|
||||||
|
for message in tool_calls_messages:
|
||||||
|
message["_tool_call_history"] = True
|
||||||
|
|
||||||
|
processed_tool_messages = self._process_tool_message_pairs(
|
||||||
|
tool_calls_messages, remove_tags=False
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.extend(processed_tool_messages)
|
||||||
|
|
||||||
|
contexts.append(
|
||||||
|
{"role": "assistant", "content": llm_response.completion_text}
|
||||||
|
)
|
||||||
|
contexts_to_save = list(
|
||||||
|
filter(lambda item: "_no_save" not in item, contexts)
|
||||||
|
)
|
||||||
|
await self.conv_manager.update_conversation(
|
||||||
|
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||||
|
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list): 消息列表
|
||||||
|
remove_tags (bool): 是否移除_tool_call_history标记
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
while i < len(messages):
|
||||||
|
current_msg = messages[i]
|
||||||
|
|
||||||
|
# 普通消息直接添加
|
||||||
|
if "_tool_call_history" not in current_msg:
|
||||||
|
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 工具调用消息成对处理
|
||||||
|
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||||
|
assistant_msg = current_msg.copy()
|
||||||
|
|
||||||
|
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||||
|
del assistant_msg["_tool_call_history"]
|
||||||
|
|
||||||
|
related_tools = []
|
||||||
|
j = i + 1
|
||||||
|
while (
|
||||||
|
j < len(messages)
|
||||||
|
and messages[j].get("role") == "tool"
|
||||||
|
and "_tool_call_history" in messages[j]
|
||||||
|
):
|
||||||
|
tool_msg = messages[j].copy()
|
||||||
|
|
||||||
|
if remove_tags:
|
||||||
|
del tool_msg["_tool_call_history"]
|
||||||
|
|
||||||
|
related_tools.append(tool_msg)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# 成对的时候添加到结果
|
||||||
|
if related_tools:
|
||||||
|
result.append(assistant_msg)
|
||||||
|
result.extend(related_tools)
|
||||||
|
|
||||||
|
i = j # 跳过已处理
|
||||||
|
else:
|
||||||
|
# 单独的tool消息
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'''
|
"""
|
||||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||||
'''
|
"""
|
||||||
|
|
||||||
from ...context import PipelineContext
|
from ...context import PipelineContext
|
||||||
from ..stage import Stage
|
from ..stage import Stage
|
||||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||||
@@ -11,27 +12,44 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
|
|||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
class StarRequestSubStage(Stage):
|
|
||||||
|
|
||||||
|
class StarRequestSubStage(Stage):
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||||
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
|
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
|
||||||
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
|
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
async def process(
|
||||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
self, event: AstrMessageEvent
|
||||||
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params")
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
|
||||||
|
"activated_handlers"
|
||||||
|
)
|
||||||
|
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra(
|
||||||
|
"handlers_parsed_params"
|
||||||
|
)
|
||||||
if not handlers_parsed_params:
|
if not handlers_parsed_params:
|
||||||
handlers_parsed_params = {}
|
handlers_parsed_params = {}
|
||||||
|
|
||||||
for handler in activated_handlers:
|
for handler in activated_handlers:
|
||||||
|
# 检查处理器是否在当前平台兼容
|
||||||
|
if (
|
||||||
|
hasattr(handler, "platform_compatible")
|
||||||
|
and handler.platform_compatible is False
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||||
try:
|
try:
|
||||||
if handler.handler_module_path not in star_map:
|
if handler.handler_module_path not in star_map:
|
||||||
# 孤立无援的 star handler
|
|
||||||
continue
|
continue
|
||||||
|
logger.debug(
|
||||||
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
|
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||||
async for ret in wrapper:
|
async for ret in wrapper:
|
||||||
yield ret
|
yield ret
|
||||||
@@ -39,8 +57,11 @@ class StarRequestSubStage(Stage):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||||
|
|
||||||
|
if event.is_at_or_wake_command:
|
||||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||||
event.set_result(MessageEventResult().message(ret))
|
event.set_result(MessageEventResult().message(ret))
|
||||||
yield
|
yield
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
@@ -3,15 +3,14 @@ from ..stage import Stage, register_stage
|
|||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from .method.llm_request import LLMRequestSubStage
|
from .method.llm_request import LLMRequestSubStage
|
||||||
from .method.star_request import StarRequestSubStage
|
from .method.star_request import StarRequestSubStage
|
||||||
from .method.dify_request import DifyRequestSubStage
|
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from astrbot.core.provider.entities import ProviderRequest
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class ProcessStage(Stage):
|
class ProcessStage(Stage):
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.config = ctx.astrbot_config
|
self.config = ctx.astrbot_config
|
||||||
@@ -22,37 +21,48 @@ class ProcessStage(Stage):
|
|||||||
self.star_request_sub_stage = StarRequestSubStage()
|
self.star_request_sub_stage = StarRequestSubStage()
|
||||||
await self.star_request_sub_stage.initialize(ctx)
|
await self.star_request_sub_stage.initialize(ctx)
|
||||||
|
|
||||||
self.dify_request_sub_stage = DifyRequestSubStage()
|
async def process(
|
||||||
await self.dify_request_sub_stage.initialize(ctx)
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
"""处理事件"""
|
||||||
'''处理事件
|
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
|
||||||
'''
|
"activated_handlers"
|
||||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
)
|
||||||
# 有插件 Handler 被激活
|
# 有插件 Handler 被激活
|
||||||
if activated_handlers:
|
if activated_handlers:
|
||||||
async for resp in self.star_request_sub_stage.process(event):
|
async for resp in self.star_request_sub_stage.process(event):
|
||||||
# 生成器返回值处理
|
# 生成器返回值处理
|
||||||
if isinstance(resp, ProviderRequest):
|
if isinstance(resp, ProviderRequest):
|
||||||
# Handler 的 LLM 请求
|
# Handler 的 LLM 请求
|
||||||
logger.debug(f"llm request -> {resp.prompt}")
|
|
||||||
event.set_extra("provider_request", resp)
|
event.set_extra("provider_request", resp)
|
||||||
|
_t = False
|
||||||
async for _ in self.llm_request_sub_stage.process(event):
|
async for _ in self.llm_request_sub_stage.process(event):
|
||||||
|
_t = True
|
||||||
|
yield
|
||||||
|
if not _t:
|
||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 调用提供商相关请求
|
# 调用 LLM 相关请求
|
||||||
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
if not self.ctx.astrbot_config["provider_settings"].get("enable", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
if not event._has_send_oper and event.is_at_or_wake_command:
|
if (
|
||||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
not event._has_send_oper
|
||||||
|
and event.is_at_or_wake_command
|
||||||
|
and not event.call_llm
|
||||||
|
):
|
||||||
|
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||||
|
if (
|
||||||
|
event.get_result() and not event.get_result().is_stopped()
|
||||||
|
) or not event.get_result():
|
||||||
|
# 事件没有终止传播
|
||||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||||
match provider.meta().type:
|
|
||||||
case "dify":
|
if not provider:
|
||||||
async for _ in self.dify_request_sub_stage.process(event):
|
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||||
yield
|
return
|
||||||
case _:
|
|
||||||
async for _ in self.llm_request_sub_stage.process(event):
|
async for _ in self.llm_request_sub_stage.process(event):
|
||||||
yield
|
yield
|
||||||
@@ -5,7 +5,6 @@ from typing import DefaultDict, Deque, Union, AsyncGenerator
|
|||||||
from ..stage import Stage, register_stage
|
from ..stage import Stage, register_stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
||||||
|
|
||||||
@@ -32,11 +31,19 @@ class RateLimitStage(Stage):
|
|||||||
"""
|
"""
|
||||||
初始化限流器,根据配置设置限流参数。
|
初始化限流器,根据配置设置限流参数。
|
||||||
"""
|
"""
|
||||||
self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count']
|
self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
|
||||||
self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time'])
|
"count"
|
||||||
self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard
|
]
|
||||||
|
self.rate_limit_time = timedelta(
|
||||||
|
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
|
||||||
|
)
|
||||||
|
self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
|
||||||
|
"strategy"
|
||||||
|
] # stall or discard
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""
|
"""
|
||||||
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
|
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
|
||||||
|
|
||||||
@@ -61,20 +68,27 @@ class RateLimitStage(Stage):
|
|||||||
stall_duration = (next_window_time - now).total_seconds()
|
stall_duration = (next_window_time - now).total_seconds()
|
||||||
|
|
||||||
match self.rl_strategy:
|
match self.rl_strategy:
|
||||||
case RateLimitStrategy.STALL:
|
case RateLimitStrategy.STALL.value:
|
||||||
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
|
logger.info(
|
||||||
|
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||||
|
)
|
||||||
await asyncio.sleep(stall_duration)
|
await asyncio.sleep(stall_duration)
|
||||||
case RateLimitStrategy.DISCARD:
|
case RateLimitStrategy.DISCARD.value:
|
||||||
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||||
|
)
|
||||||
return event.stop_event()
|
return event.stop_event()
|
||||||
|
|
||||||
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
|
self._remove_expired_timestamps(
|
||||||
|
timestamps, now + timedelta(seconds=stall_duration)
|
||||||
|
)
|
||||||
|
|
||||||
timestamps.append(now)
|
timestamps.append(now)
|
||||||
|
|
||||||
return event.continue_event()
|
def _remove_expired_timestamps(
|
||||||
|
self, timestamps: Deque[datetime], now: datetime
|
||||||
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
移除时间窗口外的时间戳。
|
移除时间窗口外的时间戳。
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,218 @@
|
|||||||
|
import random
|
||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
import traceback
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage, Stage
|
from ..stage import register_stage, Stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class RespondStage(Stage):
|
class RespondStage(Stage):
|
||||||
|
# 组件类型到其非空判断函数的映射
|
||||||
|
_component_validators = {
|
||||||
|
Comp.Plain: lambda comp: bool(
|
||||||
|
comp.text and comp.text.strip()
|
||||||
|
), # 纯文本消息需要strip
|
||||||
|
Comp.Face: lambda comp: comp.id is not None, # QQ表情
|
||||||
|
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||||
|
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||||
|
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||||
|
Comp.AtAll: lambda comp: True, # @所有人
|
||||||
|
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||||
|
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||||
|
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||||
|
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||||
|
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||||
|
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||||
|
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||||
|
Comp.Music: lambda comp: bool(comp._type)
|
||||||
|
and bool(comp.url)
|
||||||
|
and bool(comp.audio), # 音乐
|
||||||
|
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||||
|
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||||
|
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||||
|
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||||
|
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||||
|
Comp.Node: lambda comp: bool(comp.name)
|
||||||
|
and comp.uin != 0
|
||||||
|
and bool(comp.content), # 一个转发节点
|
||||||
|
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||||
|
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||||
|
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||||
|
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||||
|
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||||
|
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||||
|
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||||
|
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||||
|
}
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext):
|
async def initialize(self, ctx: PipelineContext):
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||||
|
"reply_with_mention"
|
||||||
|
]
|
||||||
|
self.reply_with_quote = ctx.astrbot_config["platform_settings"][
|
||||||
|
"reply_with_quote"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 分段回复
|
||||||
|
self.enable_seg: bool = ctx.astrbot_config["platform_settings"][
|
||||||
|
"segmented_reply"
|
||||||
|
]["enable"]
|
||||||
|
self.only_llm_result = ctx.astrbot_config["platform_settings"][
|
||||||
|
"segmented_reply"
|
||||||
|
]["only_llm_result"]
|
||||||
|
|
||||||
|
self.interval_method = ctx.astrbot_config["platform_settings"][
|
||||||
|
"segmented_reply"
|
||||||
|
]["interval_method"]
|
||||||
|
self.log_base = float(
|
||||||
|
ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"]
|
||||||
|
)
|
||||||
|
interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||||
|
"interval"
|
||||||
|
]
|
||||||
|
interval_str_ls = interval_str.replace(" ", "").split(",")
|
||||||
|
try:
|
||||||
|
self.interval = [float(t) for t in interval_str_ls]
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"解析分段回复的间隔时间失败。{e}")
|
||||||
|
self.interval = [1.5, 3.5]
|
||||||
|
logger.info(f"分段回复间隔时间:{self.interval}")
|
||||||
|
|
||||||
|
async def _word_cnt(self, text: str) -> int:
|
||||||
|
"""分段回复 统计字数"""
|
||||||
|
if all(ord(c) < 128 for c in text):
|
||||||
|
word_count = len(text.split())
|
||||||
|
else:
|
||||||
|
word_count = len([c for c in text if c.isalnum()])
|
||||||
|
return word_count
|
||||||
|
|
||||||
|
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
||||||
|
"""分段回复 计算间隔时间"""
|
||||||
|
if self.interval_method == "log":
|
||||||
|
if isinstance(comp, Comp.Plain):
|
||||||
|
wc = await self._word_cnt(comp.text)
|
||||||
|
i = math.log(wc + 1, self.log_base)
|
||||||
|
return random.uniform(i, i + 0.5)
|
||||||
|
else:
|
||||||
|
return random.uniform(1, 1.75)
|
||||||
|
else:
|
||||||
|
# random
|
||||||
|
return random.uniform(self.interval[0], self.interval[1])
|
||||||
|
|
||||||
|
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
|
||||||
|
"""检查消息链是否为空
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chain (list[BaseMessageComponent]): 包含消息对象的列表
|
||||||
|
"""
|
||||||
|
if not chain:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for comp in chain:
|
||||||
|
comp_type = type(comp)
|
||||||
|
|
||||||
|
# 检查组件类型是否在字典中
|
||||||
|
if comp_type in self._component_validators:
|
||||||
|
if self._component_validators[comp_type](comp):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||||
|
|
||||||
|
# 如果所有组件都为空
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
|
return
|
||||||
|
|
||||||
if len(result.chain) > 0:
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
# 流式结果直接交付平台适配器处理
|
||||||
|
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||||
|
await event._pre_send()
|
||||||
|
await event.send_streaming(result.async_stream)
|
||||||
|
await event._post_send()
|
||||||
|
return
|
||||||
|
elif len(result.chain) > 0:
|
||||||
|
await event._pre_send()
|
||||||
|
|
||||||
|
# 检查消息链是否为空
|
||||||
|
try:
|
||||||
|
if await self._is_empty_message_chain(result.chain):
|
||||||
|
logger.info("消息为空,跳过发送阶段")
|
||||||
|
event.clear_result()
|
||||||
|
event.stop_event()
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"空内容检查异常: {e}")
|
||||||
|
|
||||||
|
if self.enable_seg and (
|
||||||
|
(self.only_llm_result and result.is_llm_result())
|
||||||
|
or not self.only_llm_result
|
||||||
|
):
|
||||||
|
decorated_comps = []
|
||||||
|
if self.reply_with_mention:
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Comp.At):
|
||||||
|
decorated_comps.append(comp)
|
||||||
|
result.chain.remove(comp)
|
||||||
|
break
|
||||||
|
if self.reply_with_quote:
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Comp.Reply):
|
||||||
|
decorated_comps.append(comp)
|
||||||
|
result.chain.remove(comp)
|
||||||
|
break
|
||||||
|
# 分段回复
|
||||||
|
for comp in result.chain:
|
||||||
|
i = await self._calc_comp_interval(comp)
|
||||||
|
await asyncio.sleep(i)
|
||||||
|
try:
|
||||||
|
await event.send(MessageChain([*decorated_comps, comp]))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
try:
|
||||||
await event.send(result)
|
await event.send(result)
|
||||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
await event._post_send()
|
||||||
|
logger.info(
|
||||||
|
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
|
)
|
||||||
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||||
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
await handler.handler(event)
|
await handler.handler(event)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if event.is_stopped():
|
||||||
|
logger.info(
|
||||||
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
@@ -1,81 +1,258 @@
|
|||||||
import time
|
import time
|
||||||
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage
|
from ..stage import Stage, register_stage, registered_stages
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.message.message_event_result import ResultContentType
|
||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record
|
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||||
from astrbot.core import html_renderer
|
from astrbot.core import html_renderer
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class ResultDecorateStage:
|
class ResultDecorateStage(Stage):
|
||||||
async def initialize(self, ctx: PipelineContext):
|
async def initialize(self, ctx: PipelineContext):
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"]
|
||||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
"reply_with_mention"
|
||||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
]
|
||||||
self.t2i = ctx.astrbot_config['t2i']
|
self.reply_with_quote = ctx.astrbot_config["platform_settings"][
|
||||||
|
"reply_with_quote"
|
||||||
|
]
|
||||||
|
self.t2i_word_threshold = ctx.astrbot_config["t2i_word_threshold"]
|
||||||
|
try:
|
||||||
|
self.t2i_word_threshold = int(self.t2i_word_threshold)
|
||||||
|
if self.t2i_word_threshold < 50:
|
||||||
|
self.t2i_word_threshold = 50
|
||||||
|
except BaseException:
|
||||||
|
self.t2i_word_threshold = 150
|
||||||
|
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
||||||
|
self.t2i_use_network = self.t2i_strategy == "remote"
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
||||||
|
"forward_threshold"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 分段回复
|
||||||
|
self.words_count_threshold = int(
|
||||||
|
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||||
|
"words_count_threshold"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.enable_segmented_reply = ctx.astrbot_config["platform_settings"][
|
||||||
|
"segmented_reply"
|
||||||
|
]["enable"]
|
||||||
|
self.only_llm_result = ctx.astrbot_config["platform_settings"][
|
||||||
|
"segmented_reply"
|
||||||
|
]["only_llm_result"]
|
||||||
|
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
|
||||||
|
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
|
||||||
|
"segmented_reply"
|
||||||
|
]["content_cleanup_rule"]
|
||||||
|
|
||||||
|
# exception
|
||||||
|
self.content_safe_check_reply = ctx.astrbot_config["content_safety"][
|
||||||
|
"also_use_in_response"
|
||||||
|
]
|
||||||
|
self.content_safe_check_stage = None
|
||||||
|
if self.content_safe_check_reply:
|
||||||
|
for stage in registered_stages:
|
||||||
|
if stage.__class__.__name__ == "ContentSafetyCheckStage":
|
||||||
|
self.content_safe_check_stage = stage
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
result = event.get_result()
|
||||||
|
if result is None or not result.chain:
|
||||||
|
return
|
||||||
|
|
||||||
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
|
||||||
|
|
||||||
|
# 回复时检查内容安全
|
||||||
|
if (
|
||||||
|
self.content_safe_check_reply
|
||||||
|
and self.content_safe_check_stage
|
||||||
|
and result.is_llm_result()
|
||||||
|
and not is_stream # 流式输出不检查内容安全
|
||||||
|
):
|
||||||
|
text = ""
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
text += comp.text
|
||||||
|
async for _ in self.content_safe_check_stage.process(
|
||||||
|
event, check_text=text
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
# 发送消息前事件钩子
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||||
|
)
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
if is_stream:
|
||||||
|
logger.warning(
|
||||||
|
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
|
||||||
|
)
|
||||||
|
await handler.handler(event)
|
||||||
|
if event.get_result() is None or not event.get_result().chain:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。"
|
||||||
|
)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if event.is_stopped():
|
||||||
|
logger.info(
|
||||||
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 流式输出不执行下面的逻辑
|
||||||
|
if is_stream:
|
||||||
|
logger.info("流式输出已启用,跳过结果装饰阶段")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
|
|
||||||
for handler in handlers:
|
|
||||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
|
||||||
await handler.handler(event)
|
|
||||||
|
|
||||||
if len(result.chain) > 0:
|
if len(result.chain) > 0:
|
||||||
# 回复前缀
|
# 回复前缀
|
||||||
if self.reply_prefix:
|
if self.reply_prefix:
|
||||||
result.chain.insert(0, Plain(self.reply_prefix))
|
|
||||||
|
|
||||||
# TTS
|
|
||||||
if self.use_tts and result.is_llm_result():
|
|
||||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
|
||||||
plain_str = ""
|
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
plain_str += " " + comp.text
|
comp.text = self.reply_prefix + comp.text
|
||||||
else:
|
|
||||||
break
|
break
|
||||||
if plain_str:
|
|
||||||
|
# 分段回复
|
||||||
|
if self.enable_segmented_reply:
|
||||||
|
if (
|
||||||
|
self.only_llm_result and result.is_llm_result()
|
||||||
|
) or not self.only_llm_result:
|
||||||
|
new_chain = []
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
if len(comp.text) > self.words_count_threshold:
|
||||||
|
# 不分段回复
|
||||||
|
new_chain.append(comp)
|
||||||
|
continue
|
||||||
|
split_response = []
|
||||||
|
for line in comp.text.split("\n"):
|
||||||
|
split_response.extend(re.findall(self.regex, line))
|
||||||
|
if not split_response:
|
||||||
|
new_chain.append(comp)
|
||||||
|
continue
|
||||||
|
for seg in split_response:
|
||||||
|
if self.content_cleanup_rule:
|
||||||
|
seg = re.sub(self.content_cleanup_rule, "", seg)
|
||||||
|
if seg.strip():
|
||||||
|
new_chain.append(Plain(seg))
|
||||||
|
else:
|
||||||
|
# 非 Plain 类型的消息段不分段
|
||||||
|
new_chain.append(comp)
|
||||||
|
result.chain = new_chain
|
||||||
|
|
||||||
|
# TTS
|
||||||
|
if (
|
||||||
|
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||||
|
and result.is_llm_result()
|
||||||
|
):
|
||||||
|
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||||
|
new_chain = []
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||||
try:
|
try:
|
||||||
audio_path = await tts_provider.get_audio(plain_str)
|
logger.info("TTS 请求: " + comp.text)
|
||||||
|
audio_path = await tts_provider.get_audio(comp.text)
|
||||||
logger.info("TTS 结果: " + audio_path)
|
logger.info("TTS 结果: " + audio_path)
|
||||||
if audio_path:
|
if audio_path:
|
||||||
result.chain = [Record(file=audio_path, url=audio_path)]
|
new_chain.append(
|
||||||
|
Record(file=audio_path, url=audio_path)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||||
|
)
|
||||||
|
new_chain.append(comp)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
traceback.print_exc()
|
logger.error(traceback.format_exc())
|
||||||
logger.error("TTS 失败,使用文本发送。")
|
logger.error("TTS 失败,使用文本发送。")
|
||||||
|
new_chain.append(comp)
|
||||||
|
else:
|
||||||
|
new_chain.append(comp)
|
||||||
|
result.chain = new_chain
|
||||||
|
|
||||||
# 文本转图片
|
# 文本转图片
|
||||||
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
elif (
|
||||||
|
result.use_t2i_ is None and self.ctx.astrbot_config["t2i"]
|
||||||
|
) or result.use_t2i_:
|
||||||
plain_str = ""
|
plain_str = ""
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
plain_str += "\n\n" + comp.text
|
plain_str += "\n\n" + comp.text
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
if plain_str and len(plain_str) > 150:
|
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
||||||
render_start = time.time()
|
render_start = time.time()
|
||||||
try:
|
try:
|
||||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
url = await html_renderer.render_t2i(
|
||||||
|
plain_str, return_url=True, use_network=self.t2i_use_network
|
||||||
|
)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error("文本转图片失败,使用文本发送。")
|
logger.error("文本转图片失败,使用文本发送。")
|
||||||
return
|
return
|
||||||
if time.time() - render_start > 3:
|
if time.time() - render_start > 3:
|
||||||
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
|
logger.warning(
|
||||||
|
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
|
||||||
|
)
|
||||||
if url:
|
if url:
|
||||||
|
if url.startswith("http"):
|
||||||
result.chain = [Image.fromURL(url)]
|
result.chain = [Image.fromURL(url)]
|
||||||
|
else:
|
||||||
|
result.chain = [Image.fromFileSystem(url)]
|
||||||
|
|
||||||
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
|
# 触发转发消息
|
||||||
result.chain.insert(0, At(qq=event.get_sender_id()))
|
has_forwarded = False
|
||||||
|
if event.get_platform_name() == "aiocqhttp":
|
||||||
|
word_cnt = 0
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
word_cnt += len(comp.text)
|
||||||
|
if word_cnt > self.forward_threshold:
|
||||||
|
node = Node(
|
||||||
|
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
|
||||||
|
)
|
||||||
|
result.chain = [node]
|
||||||
|
has_forwarded = True
|
||||||
|
|
||||||
|
if not has_forwarded:
|
||||||
|
# at 回复
|
||||||
|
if (
|
||||||
|
self.reply_with_mention
|
||||||
|
and event.get_message_type() != MessageType.FRIEND_MESSAGE
|
||||||
|
):
|
||||||
|
result.chain.insert(
|
||||||
|
0, At(qq=event.get_sender_id(), name=event.get_sender_name())
|
||||||
|
)
|
||||||
|
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
|
||||||
|
result.chain[1].text = "\n" + result.chain[1].text
|
||||||
|
|
||||||
|
# 引用回复
|
||||||
if self.reply_with_quote:
|
if self.reply_with_quote:
|
||||||
|
if not any(isinstance(item, File) for item in result.chain):
|
||||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||||
@@ -5,43 +5,74 @@ from typing import AsyncGenerator
|
|||||||
from astrbot.core.platform import AstrMessageEvent
|
from astrbot.core.platform import AstrMessageEvent
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
class PipelineScheduler():
|
|
||||||
|
class PipelineScheduler:
|
||||||
|
"""管道调度器,负责调度各个阶段的执行"""
|
||||||
|
|
||||||
def __init__(self, context: PipelineContext):
|
def __init__(self, context: PipelineContext):
|
||||||
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__ .__name__))
|
registered_stages.sort(
|
||||||
self.ctx = context
|
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||||
|
) # 按照顺序排序
|
||||||
|
self.ctx = context # 上下文对象
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
"""初始化管道调度器时, 初始化所有阶段"""
|
||||||
for stage in registered_stages:
|
for stage in registered_stages:
|
||||||
logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||||
|
|
||||||
await stage.initialize(self.ctx)
|
await stage.initialize(self.ctx)
|
||||||
|
|
||||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||||
|
"""依次执行各个阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||||
|
"""
|
||||||
for i in range(from_stage, len(registered_stages)):
|
for i in range(from_stage, len(registered_stages)):
|
||||||
stage = registered_stages[i]
|
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||||
logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||||
coro = stage.process(event)
|
coroutine = stage.process(
|
||||||
if isinstance(coro, AsyncGenerator):
|
event
|
||||||
async for _ in coro:
|
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||||
|
|
||||||
|
if isinstance(coroutine, AsyncGenerator):
|
||||||
|
# 如果返回的是异步生成器, 实现洋葱模型的核心
|
||||||
|
async for _ in coroutine:
|
||||||
|
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
|
logger.debug(
|
||||||
|
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 递归调用, 处理所有后续阶段
|
||||||
await self._process_stages(event, i + 1)
|
await self._process_stages(event, i + 1)
|
||||||
else:
|
|
||||||
await coro
|
|
||||||
|
|
||||||
|
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
|
logger.debug(
|
||||||
|
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
|
||||||
|
# 简单地等待它执行完成, 然后继续执行下一个阶段
|
||||||
|
await coroutine
|
||||||
|
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
|
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
||||||
break
|
break
|
||||||
|
|
||||||
async def execute(self, event: AstrMessageEvent):
|
async def execute(self, event: AstrMessageEvent):
|
||||||
'''执行 pipeline'''
|
"""执行 pipeline
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
"""
|
||||||
await self._process_stages(event)
|
await self._process_stages(event)
|
||||||
|
|
||||||
|
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||||
await event.send(None)
|
await event.send(None)
|
||||||
|
|
||||||
|
|||||||
@@ -1,34 +1,45 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import inspect
|
import inspect
|
||||||
|
import traceback
|
||||||
|
from astrbot.api import logger
|
||||||
from typing import List, AsyncGenerator, Union, Awaitable
|
from typing import List, AsyncGenerator, Union, Awaitable
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from .context import PipelineContext
|
from .context import PipelineContext
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||||
|
|
||||||
registered_stages: List[Stage] = []
|
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||||
'''维护了所有已注册的 Stage 实现类'''
|
|
||||||
|
|
||||||
def register_stage(cls):
|
def register_stage(cls):
|
||||||
'''一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类
|
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
|
||||||
'''
|
|
||||||
registered_stages.append(cls())
|
registered_stages.append(cls())
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
class Stage(abc.ABC):
|
class Stage(abc.ABC):
|
||||||
'''描述一个 Pipeline 的某个阶段
|
"""描述一个 Pipeline 的某个阶段"""
|
||||||
'''
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
'''初始化阶段
|
"""初始化阶段
|
||||||
'''
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
async def process(
|
||||||
'''处理事件
|
self, event: AstrMessageEvent
|
||||||
'''
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
"""处理事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象,包含事件的相关信息
|
||||||
|
Returns:
|
||||||
|
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _call_handler(
|
async def _call_handler(
|
||||||
@@ -36,30 +47,64 @@ class Stage(abc.ABC):
|
|||||||
ctx: PipelineContext,
|
ctx: PipelineContext,
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
handler: Awaitable,
|
handler: Awaitable,
|
||||||
**params
|
*args,
|
||||||
|
**kwargs,
|
||||||
) -> AsyncGenerator[None, None]:
|
) -> AsyncGenerator[None, None]:
|
||||||
'''调用 Handler。'''
|
"""执行事件处理函数并处理其返回结果
|
||||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
|
||||||
ready_to_call = None
|
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||||
|
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||||
|
2. 协程: 执行一次并处理返回值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象
|
||||||
|
event (AstrMessageEvent): 待处理的事件对象
|
||||||
|
handler (Awaitable): 事件处理函数
|
||||||
|
*args: 传递给handler的位置参数
|
||||||
|
**kwargs: 传递给handler的关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||||
|
"""
|
||||||
|
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||||
|
|
||||||
|
trace_ = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ready_to_call = handler(event, **params)
|
ready_to_call = handler(event, *args, **kwargs)
|
||||||
except TypeError as e:
|
except TypeError as _:
|
||||||
# 向下兼容
|
# 向下兼容
|
||||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
trace_ = traceback.format_exc()
|
||||||
|
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||||
|
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(ready_to_call, AsyncGenerator):
|
if isinstance(ready_to_call, AsyncGenerator):
|
||||||
|
# 如果是一个异步生成器, 进入洋葱模型
|
||||||
|
_has_yielded = False # 是否返回过值
|
||||||
|
try:
|
||||||
async for ret in ready_to_call:
|
async for ret in ready_to_call:
|
||||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||||
|
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||||
|
_has_yielded = True
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
|
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||||
event.set_result(ret)
|
event.set_result(ret)
|
||||||
yield
|
yield # 传递控制权给上一层的process函数
|
||||||
else:
|
else:
|
||||||
yield ret
|
# 如果返回值是 None, 则不设置结果并继续
|
||||||
|
# 继续执行后续阶段
|
||||||
|
yield ret # 传递控制权给上一层的process函数
|
||||||
|
if not _has_yielded:
|
||||||
|
# 如果这个异步生成器没有执行到yield分支
|
||||||
|
yield
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Previous Error: {trace_}")
|
||||||
|
raise e
|
||||||
elif inspect.iscoroutine(ready_to_call):
|
elif inspect.iscoroutine(ready_to_call):
|
||||||
# 如果只是一个 coroutine
|
# 如果只是一个协程, 直接执行
|
||||||
ret = await ready_to_call
|
ret = await ready_to_call
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
event.set_result(ret)
|
event.set_result(ret)
|
||||||
yield
|
yield # 传递控制权给上一层的process函数
|
||||||
else:
|
else:
|
||||||
yield ret
|
yield ret # 传递控制权给上一层的process函数
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from ..stage import Stage, register_stage
|
from ..stage import Stage, register_stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
|
from astrbot import logger
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult
|
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||||
from astrbot.core.message.components import At
|
from astrbot.core.message.components import At
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
@@ -20,7 +22,19 @@ class WakingCheckStage(Stage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
"""初始化唤醒检查阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
|
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
||||||
|
"no_permission_reply", True
|
||||||
|
)
|
||||||
|
# 私聊是否需要 wake_prefix 才能唤醒机器人
|
||||||
|
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
|
||||||
|
"platform_settings"
|
||||||
|
].get("friend_message_needs_wake_prefix", False)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
@@ -28,7 +42,7 @@ class WakingCheckStage(Stage):
|
|||||||
# 设置 sender 身份
|
# 设置 sender 身份
|
||||||
event.message_str = event.message_str.strip()
|
event.message_str = event.message_str.strip()
|
||||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||||
if event.get_sender_id() == admin_id:
|
if str(event.get_sender_id()) == admin_id:
|
||||||
event.role = "admin"
|
event.role = "admin"
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -64,7 +78,7 @@ class WakingCheckStage(Stage):
|
|||||||
event.is_at_or_wake_command = True
|
event.is_at_or_wake_command = True
|
||||||
break
|
break
|
||||||
# 检查是否是私聊
|
# 检查是否是私聊
|
||||||
if event.is_private_chat():
|
if event.is_private_chat() and not self.friend_message_needs_wake_prefix:
|
||||||
is_wake = True
|
is_wake = True
|
||||||
event.is_wake = True
|
event.is_wake = True
|
||||||
event.is_at_or_wake_command = True
|
event.is_at_or_wake_command = True
|
||||||
@@ -73,44 +87,55 @@ class WakingCheckStage(Stage):
|
|||||||
# 检查插件的 handler filter
|
# 检查插件的 handler filter
|
||||||
activated_handlers = []
|
activated_handlers = []
|
||||||
handlers_parsed_params = {} # 注册了指令的 handler
|
handlers_parsed_params = {} # 注册了指令的 handler
|
||||||
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
|
|
||||||
# filter 需要满足 AND 的逻辑关系
|
|
||||||
passed = True
|
|
||||||
child_command_handler_md = None
|
|
||||||
|
|
||||||
|
for handler in star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.AdapterMessageEvent
|
||||||
|
):
|
||||||
|
# filter 需满足 AND 逻辑关系
|
||||||
|
passed = True
|
||||||
|
permission_not_pass = False
|
||||||
|
permission_filter_raise_error = False
|
||||||
if len(handler.event_filters) == 0:
|
if len(handler.event_filters) == 0:
|
||||||
# 不可能有这种情况, 也不允许有这种情况
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for filter in handler.event_filters:
|
for filter in handler.event_filters:
|
||||||
try:
|
try:
|
||||||
if isinstance(filter, CommandGroupFilter):
|
if isinstance(filter, PermissionTypeFilter):
|
||||||
"""如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata"""
|
if not filter.filter(event, self.ctx.astrbot_config):
|
||||||
ok, child_command_handler_md = filter.filter(
|
permission_not_pass = True
|
||||||
event, self.ctx.astrbot_config
|
permission_filter_raise_error = filter.raise_error
|
||||||
)
|
|
||||||
if not ok:
|
|
||||||
passed = False
|
|
||||||
else:
|
|
||||||
handler = child_command_handler_md # handler 覆盖
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
if not filter.filter(event, self.ctx.astrbot_config):
|
if not filter.filter(event, self.ctx.astrbot_config):
|
||||||
passed = False
|
passed = False
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
|
|
||||||
# yield
|
|
||||||
await event.send(
|
await event.send(
|
||||||
MessageEventResult().message(
|
MessageEventResult().message(
|
||||||
f"插件 {handler.handler_full_name} 报错:{e}"
|
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
await event._post_send()
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
passed = False
|
passed = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if passed:
|
if passed:
|
||||||
|
if permission_not_pass:
|
||||||
|
if not permission_filter_raise_error:
|
||||||
|
# 跳过
|
||||||
|
continue
|
||||||
|
if self.no_permission_reply:
|
||||||
|
await event.send(
|
||||||
|
MessageChain().message(
|
||||||
|
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await event._post_send()
|
||||||
|
logger.info(
|
||||||
|
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||||
|
)
|
||||||
|
event.stop_event()
|
||||||
|
return
|
||||||
|
|
||||||
is_wake = True
|
is_wake = True
|
||||||
event.is_wake = True
|
event.is_wake = True
|
||||||
|
|
||||||
@@ -119,6 +144,7 @@ class WakingCheckStage(Stage):
|
|||||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||||
"parsed_params"
|
"parsed_params"
|
||||||
)
|
)
|
||||||
|
|
||||||
event.clear_extra()
|
event.clear_extra()
|
||||||
|
|
||||||
event.set_extra("activated_handlers", activated_handlers)
|
event.set_extra("activated_handlers", activated_handlers)
|
||||||
|
|||||||
@@ -5,33 +5,61 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class WhitelistCheckStage(Stage):
|
class WhitelistCheckStage(Stage):
|
||||||
'''检查是否在群聊/私聊白名单
|
"""检查是否在群聊/私聊白名单"""
|
||||||
'''
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
|
||||||
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
|
|
||||||
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
|
|
||||||
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
|
|
||||||
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
|
|
||||||
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
|
|
||||||
|
|
||||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
self.enable_whitelist_check = ctx.astrbot_config["platform_settings"][
|
||||||
|
"enable_id_white_list"
|
||||||
|
]
|
||||||
|
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
|
||||||
|
self.whitelist = [
|
||||||
|
str(i).strip() for i in self.whitelist if str(i).strip() != ""
|
||||||
|
]
|
||||||
|
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
|
||||||
|
"wl_ignore_admin_on_group"
|
||||||
|
]
|
||||||
|
self.wl_ignore_admin_on_friend = ctx.astrbot_config["platform_settings"][
|
||||||
|
"wl_ignore_admin_on_friend"
|
||||||
|
]
|
||||||
|
self.wl_log = ctx.astrbot_config["platform_settings"]["id_whitelist_log"]
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
if not self.enable_whitelist_check:
|
if not self.enable_whitelist_check:
|
||||||
|
# 白名单检查未启用
|
||||||
return
|
return
|
||||||
|
|
||||||
if event.get_platform_name() == 'webchat':
|
if len(self.whitelist) == 0:
|
||||||
|
# 白名单为空,不检查
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.get_platform_name() == "webchat":
|
||||||
# WebChat 豁免
|
# WebChat 豁免
|
||||||
return
|
return
|
||||||
|
|
||||||
# 检查是否在白名单
|
# 检查是否在白名单
|
||||||
if self.wl_ignore_admin_on_group:
|
if self.wl_ignore_admin_on_group:
|
||||||
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
|
if (
|
||||||
|
event.role == "admin"
|
||||||
|
and event.get_message_type() == MessageType.GROUP_MESSAGE
|
||||||
|
):
|
||||||
return
|
return
|
||||||
if self.wl_ignore_admin_on_friend:
|
if self.wl_ignore_admin_on_friend:
|
||||||
if event.role == 'admin' and event.get_message_type() == MessageType.FRIEND_MESSAGE:
|
if (
|
||||||
|
event.role == "admin"
|
||||||
|
and event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||||
|
):
|
||||||
return
|
return
|
||||||
if event.unified_msg_origin not in self.whitelist:
|
if (
|
||||||
|
event.unified_msg_origin not in self.whitelist
|
||||||
|
and str(event.get_group_id()).strip() not in self.whitelist
|
||||||
|
):
|
||||||
if self.wl_log:
|
if self.wl_log:
|
||||||
logger.info(f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。")
|
logger.info(
|
||||||
|
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
|
||||||
|
)
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
@@ -1,4 +1,14 @@
|
|||||||
from .platform import Platform
|
from .platform import Platform
|
||||||
from .astr_message_event import AstrMessageEvent
|
from .astr_message_event import AstrMessageEvent
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
from .astrbot_message import AstrBotMessage, MessageMember, MessageType
|
from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Platform",
|
||||||
|
"AstrMessageEvent",
|
||||||
|
"PlatformMetadata",
|
||||||
|
"AstrBotMessage",
|
||||||
|
"MessageMember",
|
||||||
|
"MessageType",
|
||||||
|
"Group",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,13 +1,25 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from .astrbot_message import AstrBotMessage
|
from typing import List, Union, Optional, AsyncGenerator
|
||||||
from .platform_metadata import PlatformMetadata
|
|
||||||
|
from astrbot.core.db.po import Conversation
|
||||||
|
from astrbot.core.message.components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
BaseMessageComponent,
|
||||||
|
Face,
|
||||||
|
At,
|
||||||
|
AtAll,
|
||||||
|
Forward,
|
||||||
|
Reply,
|
||||||
|
)
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from typing import List, Union
|
from astrbot.core.provider.entities import ProviderRequest
|
||||||
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
|
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from .astrbot_message import AstrBotMessage, Group
|
||||||
|
from .platform_metadata import PlatformMetadata
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -24,33 +36,44 @@ class MessageSesion:
|
|||||||
platform_name, message_type, session_id = session_str.split(":")
|
platform_name, message_type, session_id = session_str.split(":")
|
||||||
return MessageSesion(platform_name, MessageType(message_type), session_id)
|
return MessageSesion(platform_name, MessageType(message_type), session_id)
|
||||||
|
|
||||||
|
|
||||||
class AstrMessageEvent(abc.ABC):
|
class AstrMessageEvent(abc.ABC):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
message_str: str,
|
message_str: str,
|
||||||
message_obj: AstrBotMessage,
|
message_obj: AstrBotMessage,
|
||||||
platform_meta: PlatformMetadata,
|
platform_meta: PlatformMetadata,
|
||||||
session_id: str,):
|
session_id: str,
|
||||||
|
):
|
||||||
self.message_str = message_str
|
self.message_str = message_str
|
||||||
|
"""纯文本的消息"""
|
||||||
self.message_obj = message_obj
|
self.message_obj = message_obj
|
||||||
|
"""消息对象, AstrBotMessage。带有完整的消息结构。"""
|
||||||
self.platform_meta = platform_meta
|
self.platform_meta = platform_meta
|
||||||
|
"""消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp"""
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
|
"""用户的会话 ID。可以直接使用下面的 unified_msg_origin"""
|
||||||
self.role = "member"
|
self.role = "member"
|
||||||
self.is_wake = False # 是否通过 WakingStage
|
"""用户是否是管理员。如果是管理员,这里是 admin"""
|
||||||
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True)
|
self.is_wake = False
|
||||||
|
"""是否唤醒(是否通过 WakingStage)"""
|
||||||
|
self.is_at_or_wake_command = False
|
||||||
|
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||||
self._extras = {}
|
self._extras = {}
|
||||||
self.session = MessageSesion(
|
self.session = MessageSesion(
|
||||||
platform_name=platform_meta.name,
|
platform_name=platform_meta.name,
|
||||||
message_type=message_obj.type,
|
message_type=message_obj.type,
|
||||||
session_id=session_id
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
self.unified_msg_origin = str(self.session)
|
self.unified_msg_origin = str(self.session)
|
||||||
|
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||||
self._result: MessageEventResult = None
|
self._result: MessageEventResult = None
|
||||||
'''消息事件的结果'''
|
"""消息事件的结果"""
|
||||||
|
|
||||||
self._has_send_oper = False
|
self._has_send_oper = False
|
||||||
'''是否有过至少一次发送操作'''
|
"""在此次事件中是否有过至少一次发送消息的操作"""
|
||||||
|
self.call_llm = False
|
||||||
|
"""是否在此消息事件中禁止默认的 LLM 请求"""
|
||||||
|
|
||||||
# back_compability
|
# back_compability
|
||||||
self.platform = platform_meta
|
self.platform = platform_meta
|
||||||
@@ -58,10 +81,13 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
def get_platform_name(self):
|
def get_platform_name(self):
|
||||||
return self.platform_meta.name
|
return self.platform_meta.name
|
||||||
|
|
||||||
|
def get_platform_id(self):
|
||||||
|
return self.platform_meta.id
|
||||||
|
|
||||||
def get_message_str(self) -> str:
|
def get_message_str(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取消息字符串。
|
获取消息字符串。
|
||||||
'''
|
"""
|
||||||
return self.message_str
|
return self.message_str
|
||||||
|
|
||||||
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
|
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
|
||||||
@@ -80,107 +106,122 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
elif isinstance(i, Forward):
|
elif isinstance(i, Forward):
|
||||||
# 转发消息
|
# 转发消息
|
||||||
outline += "[转发消息]"
|
outline += "[转发消息]"
|
||||||
|
elif isinstance(i, Reply):
|
||||||
|
# 引用回复
|
||||||
|
if i.message_str:
|
||||||
|
outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
|
||||||
|
else:
|
||||||
|
outline += "[引用消息]"
|
||||||
else:
|
else:
|
||||||
outline += f"[{i.type}]"
|
outline += f"[{i.type}]"
|
||||||
|
outline += " "
|
||||||
return outline
|
return outline
|
||||||
|
|
||||||
def get_message_outline(self) -> str:
|
def get_message_outline(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取消息概要。
|
获取消息概要。
|
||||||
|
|
||||||
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
|
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
|
||||||
'''
|
"""
|
||||||
return self._outline_chain(self.message_obj.message)
|
return self._outline_chain(self.message_obj.message)
|
||||||
|
|
||||||
def get_messages(self) -> List[BaseMessageComponent]:
|
def get_messages(self) -> List[BaseMessageComponent]:
|
||||||
'''
|
"""
|
||||||
获取消息链。
|
获取消息链。
|
||||||
'''
|
"""
|
||||||
return self.message_obj.message
|
return self.message_obj.message
|
||||||
|
|
||||||
def get_message_type(self) -> MessageType:
|
def get_message_type(self) -> MessageType:
|
||||||
'''
|
"""
|
||||||
获取消息类型。
|
获取消息类型。
|
||||||
'''
|
"""
|
||||||
return self.message_obj.type
|
return self.message_obj.type
|
||||||
|
|
||||||
def get_session_id(self) -> str:
|
def get_session_id(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取会话id。
|
获取会话id。
|
||||||
'''
|
"""
|
||||||
return self.session_id
|
return self.session_id
|
||||||
|
|
||||||
def get_group_id(self) -> str:
|
def get_group_id(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取群组id。如果不是群组消息,返回空字符串。
|
获取群组id。如果不是群组消息,返回空字符串。
|
||||||
'''
|
"""
|
||||||
return self.message_obj.group_id
|
return self.message_obj.group_id
|
||||||
|
|
||||||
def get_self_id(self) -> str:
|
def get_self_id(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取机器人自身的id。
|
获取机器人自身的id。
|
||||||
'''
|
"""
|
||||||
return self.message_obj.self_id
|
return self.message_obj.self_id
|
||||||
|
|
||||||
def get_sender_id(self) -> str:
|
def get_sender_id(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取消息发送者的id。
|
获取消息发送者的id。
|
||||||
'''
|
"""
|
||||||
return self.message_obj.sender.user_id
|
return self.message_obj.sender.user_id
|
||||||
|
|
||||||
def get_sender_name(self) -> str:
|
def get_sender_name(self) -> str:
|
||||||
'''
|
"""
|
||||||
获取消息发送者的名称。(可能会返回空字符串)
|
获取消息发送者的名称。(可能会返回空字符串)
|
||||||
'''
|
"""
|
||||||
return self.message_obj.sender.nickname
|
return self.message_obj.sender.nickname
|
||||||
|
|
||||||
def set_extra(self, key, value):
|
def set_extra(self, key, value):
|
||||||
'''
|
"""
|
||||||
设置额外的信息。
|
设置额外的信息。
|
||||||
'''
|
"""
|
||||||
self._extras[key] = value
|
self._extras[key] = value
|
||||||
|
|
||||||
def get_extra(self, key = None):
|
def get_extra(self, key=None):
|
||||||
'''
|
"""
|
||||||
获取额外的信息。
|
获取额外的信息。
|
||||||
'''
|
"""
|
||||||
if key is None:
|
if key is None:
|
||||||
return self._extras
|
return self._extras
|
||||||
return self._extras.get(key, None)
|
return self._extras.get(key, None)
|
||||||
|
|
||||||
def clear_extra(self):
|
def clear_extra(self):
|
||||||
'''
|
"""
|
||||||
清除额外的信息。
|
清除额外的信息。
|
||||||
'''
|
"""
|
||||||
self._extras.clear()
|
self._extras.clear()
|
||||||
|
|
||||||
def is_private_chat(self) -> bool:
|
def is_private_chat(self) -> bool:
|
||||||
'''
|
"""
|
||||||
是否是私聊。
|
是否是私聊。
|
||||||
'''
|
"""
|
||||||
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
|
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
|
||||||
|
|
||||||
def is_wake_up(self) -> bool:
|
def is_wake_up(self) -> bool:
|
||||||
'''
|
"""
|
||||||
是否是唤醒机器人的事件。
|
是否是唤醒机器人的事件。
|
||||||
'''
|
"""
|
||||||
return self.is_wake
|
return self.is_wake
|
||||||
|
|
||||||
def is_admin(self) -> bool:
|
def is_admin(self) -> bool:
|
||||||
'''
|
"""
|
||||||
是否是管理员。
|
是否是管理员。
|
||||||
'''
|
"""
|
||||||
return self.role == "admin"
|
return self.role == "admin"
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
|
||||||
'''
|
"""发送流式消息到消息平台,使用异步生成器。
|
||||||
发送消息到消息平台。
|
目前仅支持: telegram,qq official 私聊。
|
||||||
'''
|
"""
|
||||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
|
asyncio.create_task(
|
||||||
|
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||||
|
)
|
||||||
self._has_send_oper = True
|
self._has_send_oper = True
|
||||||
|
|
||||||
|
async def _pre_send(self):
|
||||||
|
"""调度器会在执行 send() 前调用该方法"""
|
||||||
|
|
||||||
|
async def _post_send(self):
|
||||||
|
"""调度器会在执行 send() 后调用该方法"""
|
||||||
|
|
||||||
def set_result(self, result: Union[MessageEventResult, str]):
|
def set_result(self, result: Union[MessageEventResult, str]):
|
||||||
'''设置消息事件的结果。
|
"""设置消息事件的结果。
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。
|
事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。
|
||||||
@@ -199,51 +240,57 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
|
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
|
||||||
return
|
return
|
||||||
```
|
```
|
||||||
'''
|
"""
|
||||||
if isinstance(result, str):
|
if isinstance(result, str):
|
||||||
result = MessageEventResult().message(result)
|
result = MessageEventResult().message(result)
|
||||||
self._result = result
|
self._result = result
|
||||||
|
|
||||||
def stop_event(self):
|
def stop_event(self):
|
||||||
'''终止事件传播。
|
"""终止事件传播。"""
|
||||||
'''
|
|
||||||
if self._result is None:
|
if self._result is None:
|
||||||
self.set_result(MessageEventResult().stop_event())
|
self.set_result(MessageEventResult().stop_event())
|
||||||
else:
|
else:
|
||||||
self._result.stop_event()
|
self._result.stop_event()
|
||||||
|
|
||||||
def continue_event(self):
|
def continue_event(self):
|
||||||
'''继续事件传播。
|
"""继续事件传播。"""
|
||||||
'''
|
|
||||||
if self._result is None:
|
if self._result is None:
|
||||||
self.set_result(MessageEventResult().continue_event())
|
self.set_result(MessageEventResult().continue_event())
|
||||||
else:
|
else:
|
||||||
self._result.continue_event()
|
self._result.continue_event()
|
||||||
|
|
||||||
def is_stopped(self) -> bool:
|
def is_stopped(self) -> bool:
|
||||||
'''
|
"""
|
||||||
是否终止事件传播。
|
是否终止事件传播。
|
||||||
'''
|
"""
|
||||||
if self._result is None:
|
if self._result is None:
|
||||||
return False # 默认是继续传播
|
return False # 默认是继续传播
|
||||||
return self._result.is_stopped()
|
return self._result.is_stopped()
|
||||||
|
|
||||||
|
def should_call_llm(self, call_llm: bool):
|
||||||
|
"""
|
||||||
|
是否在此消息事件中禁止默认的 LLM 请求。
|
||||||
|
|
||||||
|
只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。
|
||||||
|
"""
|
||||||
|
self.call_llm = call_llm
|
||||||
|
|
||||||
def get_result(self) -> MessageEventResult:
|
def get_result(self) -> MessageEventResult:
|
||||||
'''
|
"""
|
||||||
获取消息事件的结果。
|
获取消息事件的结果。
|
||||||
'''
|
"""
|
||||||
return self._result
|
return self._result
|
||||||
|
|
||||||
def clear_result(self):
|
def clear_result(self):
|
||||||
'''
|
"""
|
||||||
清除消息事件的结果。
|
清除消息事件的结果。
|
||||||
'''
|
"""
|
||||||
self._result = None
|
self._result = None
|
||||||
|
|
||||||
'''消息链相关'''
|
"""消息链相关"""
|
||||||
|
|
||||||
def make_result(self) -> MessageEventResult:
|
def make_result(self) -> MessageEventResult:
|
||||||
'''
|
"""
|
||||||
创建一个空的消息事件结果。
|
创建一个空的消息事件结果。
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -255,58 +302,99 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
yield event.make_result().url_image("https://example.com/image.jpg")
|
yield event.make_result().url_image("https://example.com/image.jpg")
|
||||||
yield event.make_result().file_image("image.jpg")
|
yield event.make_result().file_image("image.jpg")
|
||||||
```
|
```
|
||||||
'''
|
"""
|
||||||
return MessageEventResult()
|
return MessageEventResult()
|
||||||
|
|
||||||
def plain_result(self, text: str) -> MessageEventResult:
|
def plain_result(self, text: str) -> MessageEventResult:
|
||||||
'''
|
"""
|
||||||
创建一个空的消息事件结果,只包含一条文本消息。
|
创建一个空的消息事件结果,只包含一条文本消息。
|
||||||
'''
|
"""
|
||||||
return MessageEventResult().message(text)
|
return MessageEventResult().message(text)
|
||||||
|
|
||||||
def image_result(self, url_or_path: str) -> MessageEventResult:
|
def image_result(self, url_or_path: str) -> MessageEventResult:
|
||||||
'''
|
"""
|
||||||
创建一个空的消息事件结果,只包含一条图片消息。
|
创建一个空的消息事件结果,只包含一条图片消息。
|
||||||
|
|
||||||
根据开头是否包含 http 来判断是网络图片还是本地图片。
|
根据开头是否包含 http 来判断是网络图片还是本地图片。
|
||||||
'''
|
"""
|
||||||
if url_or_path.startswith("http"):
|
if url_or_path.startswith("http"):
|
||||||
return MessageEventResult().url_image(url_or_path)
|
return MessageEventResult().url_image(url_or_path)
|
||||||
return MessageEventResult().file_image(url_or_path)
|
return MessageEventResult().file_image(url_or_path)
|
||||||
|
|
||||||
def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult:
|
def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult:
|
||||||
'''
|
"""
|
||||||
创建一个空的消息事件结果,包含指定的消息链。
|
创建一个空的消息事件结果,包含指定的消息链。
|
||||||
'''
|
"""
|
||||||
mer = MessageEventResult()
|
mer = MessageEventResult()
|
||||||
mer.chain = chain
|
mer.chain = chain
|
||||||
return mer
|
return mer
|
||||||
|
|
||||||
'''LLM 请求相关'''
|
"""LLM 请求相关"""
|
||||||
|
|
||||||
def request_llm(
|
def request_llm(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
func_tool_manager=None,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
image_urls: List[str] = None,
|
image_urls: List[str] = [],
|
||||||
contexts: List = None,
|
contexts: List = [],
|
||||||
system_prompt: str = ""
|
system_prompt: str = "",
|
||||||
|
conversation: Conversation = None,
|
||||||
) -> ProviderRequest:
|
) -> ProviderRequest:
|
||||||
'''
|
"""
|
||||||
创建一个 LLM 请求。
|
创建一个 LLM 请求。
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
```py
|
```py
|
||||||
yield event.request_llm(prompt="hi")
|
yield event.request_llm(prompt="hi")
|
||||||
```
|
```
|
||||||
|
prompt: 提示词
|
||||||
|
|
||||||
|
system_prompt: 系统提示词
|
||||||
|
|
||||||
|
session_id: 已经过时,留空即可
|
||||||
|
|
||||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||||
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
|
|
||||||
'''
|
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。如果同时传入了 conversation,将会忽略 conversation。
|
||||||
|
|
||||||
|
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||||
|
|
||||||
|
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(contexts) > 0 and conversation:
|
||||||
|
conversation = None
|
||||||
|
|
||||||
return ProviderRequest(
|
return ProviderRequest(
|
||||||
prompt = prompt,
|
prompt=prompt,
|
||||||
session_id = session_id,
|
session_id=session_id,
|
||||||
image_urls = image_urls,
|
image_urls=image_urls,
|
||||||
contexts = contexts,
|
func_tool=func_tool_manager,
|
||||||
system_prompt = system_prompt
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
conversation=conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""平台适配器"""
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
"""发送消息到消息平台。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (MessageChain): 消息链,具体使用方式请参考文档。
|
||||||
|
"""
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||||
|
)
|
||||||
|
self._has_send_oper = True
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||||
|
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||||
|
|
||||||
|
适配情况:
|
||||||
|
|
||||||
|
- gewechat
|
||||||
|
- aiocqhttp(OneBotv11)
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -4,15 +4,53 @@ from dataclasses import dataclass
|
|||||||
from astrbot.core.message.components import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent
|
||||||
from .message_type import MessageType
|
from .message_type import MessageType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageMember():
|
class MessageMember:
|
||||||
user_id: str # 发送者id
|
user_id: str # 发送者id
|
||||||
nickname: str = None
|
nickname: str = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# 使用 f-string 来构建返回的字符串表示形式
|
||||||
|
return (
|
||||||
|
f"User ID: {self.user_id},"
|
||||||
|
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Group:
|
||||||
|
group_id: str
|
||||||
|
"""群号"""
|
||||||
|
group_name: str = None
|
||||||
|
"""群名称"""
|
||||||
|
group_avatar: str = None
|
||||||
|
"""群头像"""
|
||||||
|
group_owner: str = None
|
||||||
|
"""群主 id"""
|
||||||
|
group_admins: List[str] = None
|
||||||
|
"""群管理员 id"""
|
||||||
|
members: List[MessageMember] = None
|
||||||
|
"""所有群成员"""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# 使用 f-string 来构建返回的字符串表示形式
|
||||||
|
return (
|
||||||
|
f"Group ID: {self.group_id}\n"
|
||||||
|
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
|
||||||
|
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
|
||||||
|
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
|
||||||
|
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
|
||||||
|
f"Members Len: {len(self.members) if self.members else 0}\n"
|
||||||
|
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AstrBotMessage:
|
class AstrBotMessage:
|
||||||
'''
|
"""
|
||||||
AstrBot 的消息对象
|
AstrBot 的消息对象
|
||||||
'''
|
"""
|
||||||
|
|
||||||
type: MessageType # 消息类型
|
type: MessageType # 消息类型
|
||||||
self_id: str # 机器人的识别id
|
self_id: str # 机器人的识别id
|
||||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import traceback
|
||||||
|
import asyncio
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from .platform import Platform
|
from .platform import Platform
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -6,42 +8,147 @@ from .register import platform_cls_map
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||||
|
|
||||||
class PlatformManager():
|
|
||||||
|
class PlatformManager:
|
||||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||||
self.platform_insts: List[Platform] = []
|
self.platform_insts: List[Platform] = []
|
||||||
'''加载的 Platform 的实例'''
|
"""加载的 Platform 的实例"""
|
||||||
|
|
||||||
self.platforms_config = config['platform']
|
self._inst_map = {}
|
||||||
self.settings = config['platform_settings']
|
|
||||||
|
self.platforms_config = config["platform"]
|
||||||
|
self.settings = config["platform_settings"]
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue
|
||||||
|
|
||||||
for platform in self.platforms_config:
|
|
||||||
if not platform['enable']:
|
|
||||||
continue
|
|
||||||
match platform['type']:
|
|
||||||
case "aiocqhttp":
|
|
||||||
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
|
|
||||||
case "qq_official":
|
|
||||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
|
||||||
case "vchat":
|
|
||||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
|
||||||
case "gewechat":
|
|
||||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
"""初始化所有平台适配器"""
|
||||||
for platform in self.platforms_config:
|
for platform in self.platforms_config:
|
||||||
if not platform['enable']:
|
try:
|
||||||
continue
|
await self.load_platform(platform)
|
||||||
if platform['type'] not in platform_cls_map:
|
except Exception as e:
|
||||||
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
|
||||||
continue
|
|
||||||
cls_type = platform_cls_map[platform['type']]
|
# 网页聊天
|
||||||
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
|
||||||
inst = cls_type(platform, self.settings, self.event_queue)
|
self.platform_insts.append(webchat_inst)
|
||||||
|
asyncio.create_task(
|
||||||
|
self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat"))
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_platform(self, platform_config: dict):
|
||||||
|
"""实例化一个平台"""
|
||||||
|
# 动态导入
|
||||||
|
try:
|
||||||
|
if not platform_config["enable"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ..."
|
||||||
|
)
|
||||||
|
match platform_config["type"]:
|
||||||
|
case "aiocqhttp":
|
||||||
|
from .sources.aiocqhttp.aiocqhttp_platform_adapter import (
|
||||||
|
AiocqhttpAdapter, # noqa: F401
|
||||||
|
)
|
||||||
|
case "qq_official":
|
||||||
|
from .sources.qqofficial.qqofficial_platform_adapter import (
|
||||||
|
QQOfficialPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
|
case "qq_official_webhook":
|
||||||
|
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||||
|
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
|
case "gewechat":
|
||||||
|
from .sources.gewechat.gewechat_platform_adapter import (
|
||||||
|
GewechatPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
|
case "lark":
|
||||||
|
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||||
|
case "dingtalk":
|
||||||
|
from .sources.dingtalk.dingtalk_adapter import (
|
||||||
|
DingtalkPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
|
case "telegram":
|
||||||
|
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||||
|
case "wecom":
|
||||||
|
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
logger.error(
|
||||||
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
|
||||||
|
|
||||||
|
if platform_config["type"] not in platform_cls_map:
|
||||||
|
logger.error(
|
||||||
|
f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
cls_type = platform_cls_map[platform_config["type"]]
|
||||||
|
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
|
||||||
|
self._inst_map[platform_config["id"]] = {
|
||||||
|
"inst": inst,
|
||||||
|
"client_id": inst.client_self_id,
|
||||||
|
}
|
||||||
self.platform_insts.append(inst)
|
self.platform_insts.append(inst)
|
||||||
|
|
||||||
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
|
asyncio.create_task(
|
||||||
|
self._task_wrapper(
|
||||||
|
asyncio.create_task(
|
||||||
|
inst.run(),
|
||||||
|
name=f"platform_{platform_config['type']}_{platform_config['id']}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _task_wrapper(self, task: asyncio.Task):
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||||
|
for line in traceback.format_exc().split("\n"):
|
||||||
|
logger.error(f"| {line}")
|
||||||
|
logger.error("-------")
|
||||||
|
|
||||||
|
async def reload(self, platform_config: dict):
|
||||||
|
await self.terminate_platform(platform_config["id"])
|
||||||
|
if platform_config["enable"]:
|
||||||
|
await self.load_platform(platform_config)
|
||||||
|
|
||||||
|
# 和配置文件保持同步
|
||||||
|
config_ids = [provider["id"] for provider in self.platforms_config]
|
||||||
|
for key in list(self._inst_map.keys()):
|
||||||
|
if key not in config_ids:
|
||||||
|
await self.terminate_platform(key)
|
||||||
|
|
||||||
|
async def terminate_platform(self, platform_id: str):
|
||||||
|
if platform_id in self._inst_map:
|
||||||
|
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||||
|
|
||||||
|
# client_id = self._inst_map.pop(platform_id, None)
|
||||||
|
info = self._inst_map.pop(platform_id, None)
|
||||||
|
client_id = info["client_id"]
|
||||||
|
inst = info["inst"]
|
||||||
|
try:
|
||||||
|
self.platform_insts.remove(
|
||||||
|
next(
|
||||||
|
inst
|
||||||
|
for inst in self.platform_insts
|
||||||
|
if inst.client_self_id == client_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
|
||||||
|
|
||||||
|
if getattr(inst, "terminate", None):
|
||||||
|
await inst.terminate()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
for inst in self.platform_insts:
|
||||||
|
if getattr(inst, "terminate", None):
|
||||||
|
await inst.terminate()
|
||||||
|
|
||||||
def get_insts(self):
|
def get_insts(self):
|
||||||
return self.platform_insts
|
return self.platform_insts
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class MessageType(Enum):
|
class MessageType(Enum):
|
||||||
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
|
GROUP_MESSAGE = "GroupMessage" # 群组形式的消息
|
||||||
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
|
FRIEND_MESSAGE = "FriendMessage" # 私聊、好友等单聊消息
|
||||||
OTHER_MESSAGE = 'OtherMessage' # 其他类型的消息,如系统消息等
|
OTHER_MESSAGE = "OtherMessage" # 其他类型的消息,如系统消息等
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import uuid
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
@@ -7,36 +8,52 @@ from astrbot.core.message.message_event_result import MessageChain
|
|||||||
from .astr_message_event import MessageSesion
|
from .astr_message_event import MessageSesion
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
|
|
||||||
|
|
||||||
class Platform(abc.ABC):
|
class Platform(abc.ABC):
|
||||||
def __init__(self, event_queue: Queue):
|
def __init__(self, event_queue: Queue):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||||
self._event_queue = event_queue
|
self._event_queue = event_queue
|
||||||
|
self.client_self_id = uuid.uuid4().hex
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self) -> Awaitable[Any]:
|
def run(self) -> Awaitable[Any]:
|
||||||
'''
|
"""
|
||||||
得到一个平台的运行实例,需要返回一个协程对象。
|
得到一个平台的运行实例,需要返回一个协程对象。
|
||||||
'''
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""
|
||||||
|
终止一个平台的运行实例。
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
'''
|
"""
|
||||||
得到一个平台的元数据。
|
得到一个平台的元数据。
|
||||||
'''
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain) -> Awaitable[Any]:
|
async def send_by_session(
|
||||||
'''
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
) -> Awaitable[Any]:
|
||||||
|
"""
|
||||||
通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||||
|
|
||||||
异步方法。
|
异步方法。
|
||||||
'''
|
"""
|
||||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.meta().name)
|
await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name)
|
||||||
|
|
||||||
def commit_event(self, event: AstrMessageEvent):
|
def commit_event(self, event: AstrMessageEvent):
|
||||||
'''
|
"""
|
||||||
提交一个事件到事件队列。
|
提交一个事件到事件队列。
|
||||||
'''
|
"""
|
||||||
self._event_queue.put_nowait(event)
|
self._event_queue.put_nowait(event)
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
"""
|
||||||
|
获取平台的客户端对象。
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlatformMetadata():
|
class PlatformMetadata:
|
||||||
name: str
|
name: str
|
||||||
'''平台的名称'''
|
"""平台的名称"""
|
||||||
description: str
|
description: str
|
||||||
'''平台的描述'''
|
"""平台的描述"""
|
||||||
|
id: str = None
|
||||||
|
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||||
|
|
||||||
default_config_tmpl: dict = None
|
default_config_tmpl: dict = None
|
||||||
'''平台的默认配置模板'''
|
"""平台的默认配置模板"""
|
||||||
adapter_display_name: str = None
|
adapter_display_name: str = None
|
||||||
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''
|
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||||
|
|||||||
@@ -3,38 +3,42 @@ from .platform_metadata import PlatformMetadata
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
platform_registry: List[PlatformMetadata] = []
|
platform_registry: List[PlatformMetadata] = []
|
||||||
'''维护了通过装饰器注册的平台适配器'''
|
"""维护了通过装饰器注册的平台适配器"""
|
||||||
platform_cls_map: Dict[str, Type] = {}
|
platform_cls_map: Dict[str, Type] = {}
|
||||||
'''维护了平台适配器名称和适配器类的映射'''
|
"""维护了平台适配器名称和适配器类的映射"""
|
||||||
|
|
||||||
|
|
||||||
def register_platform_adapter(
|
def register_platform_adapter(
|
||||||
adapter_name: str,
|
adapter_name: str,
|
||||||
desc: str,
|
desc: str,
|
||||||
default_config_tmpl: dict = None,
|
default_config_tmpl: dict = None,
|
||||||
adapter_display_name: str = None
|
adapter_display_name: str = None,
|
||||||
):
|
):
|
||||||
'''用于注册平台适配器的带参装饰器。
|
"""用于注册平台适配器的带参装饰器。
|
||||||
|
|
||||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
if adapter_name in platform_cls_map:
|
if adapter_name in platform_cls_map:
|
||||||
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
|
raise ValueError(
|
||||||
|
f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。"
|
||||||
|
)
|
||||||
|
|
||||||
# 添加必备选项
|
# 添加必备选项
|
||||||
if default_config_tmpl:
|
if default_config_tmpl:
|
||||||
if 'type' not in default_config_tmpl:
|
if "type" not in default_config_tmpl:
|
||||||
default_config_tmpl['type'] = adapter_name
|
default_config_tmpl["type"] = adapter_name
|
||||||
if 'enable' not in default_config_tmpl:
|
if "enable" not in default_config_tmpl:
|
||||||
default_config_tmpl['enable'] = False
|
default_config_tmpl["enable"] = False
|
||||||
if 'id' not in default_config_tmpl:
|
if "id" not in default_config_tmpl:
|
||||||
default_config_tmpl['id'] = adapter_name
|
default_config_tmpl["id"] = adapter_name
|
||||||
|
|
||||||
pm = PlatformMetadata(
|
pm = PlatformMetadata(
|
||||||
name=adapter_name,
|
name=adapter_name,
|
||||||
description=desc,
|
description=desc,
|
||||||
default_config_tmpl=default_config_tmpl,
|
default_config_tmpl=default_config_tmpl,
|
||||||
adapter_display_name=adapter_display_name
|
adapter_display_name=adapter_display_name,
|
||||||
)
|
)
|
||||||
platform_registry.append(pm)
|
platform_registry.append(pm)
|
||||||
platform_cls_map[adapter_name] = cls
|
platform_cls_map[adapter_name] = cls
|
||||||
|
|||||||
@@ -1,50 +1,139 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import typing
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image, Record
|
from astrbot.api.platform import Group, MessageMember
|
||||||
|
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
||||||
from aiocqhttp import CQHttp
|
from aiocqhttp import CQHttp
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
|
||||||
|
|
||||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||||
def __init__(self, message_str, message_obj, platform_meta, session_id, bot: CQHttp):
|
def __init__(
|
||||||
|
self, message_str, message_obj, platform_meta, session_id, bot: CQHttp
|
||||||
|
):
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _parse_onebot_json(message_chain: MessageChain):
|
async def _parse_onebot_json(message_chain: MessageChain):
|
||||||
'''解析成 OneBot json 格式'''
|
"""解析成 OneBot json 格式"""
|
||||||
ret = []
|
ret = []
|
||||||
for segment in message_chain.chain:
|
for segment in message_chain.chain:
|
||||||
d = segment.toDict()
|
d = segment.toDict()
|
||||||
if isinstance(segment, Plain):
|
if isinstance(segment, Plain):
|
||||||
d['type'] = 'text'
|
d["type"] = "text"
|
||||||
if isinstance(segment, (Image, Record)):
|
d["data"]["text"] = segment.text.strip()
|
||||||
|
# 如果是空文本或者只带换行符的文本,不发送
|
||||||
|
if not d["data"]["text"]:
|
||||||
|
continue
|
||||||
|
elif isinstance(segment, (Image, Record)):
|
||||||
# convert to base64
|
# convert to base64
|
||||||
if segment.file and segment.file.startswith("file:///"):
|
bs64 = await segment.convert_to_base64()
|
||||||
bs64_data = file_to_base64(segment.file[8:])
|
d["data"] = {
|
||||||
image_file_path = segment.file[8:]
|
"file": bs64,
|
||||||
elif segment.file and segment.file.startswith("http"):
|
}
|
||||||
image_file_path = await download_image_by_url(segment.file)
|
elif isinstance(segment, At):
|
||||||
bs64_data = file_to_base64(image_file_path)
|
d["data"] = {
|
||||||
else:
|
"qq": str(segment.qq) # 转换为字符串
|
||||||
bs64_data = file_to_base64(segment.file)
|
|
||||||
d['data'] = {
|
|
||||||
'file': bs64_data,
|
|
||||||
}
|
}
|
||||||
ret.append(d)
|
ret.append(d)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
|
||||||
|
if not ret:
|
||||||
return
|
return
|
||||||
|
|
||||||
if message.is_split_: # 分条发送
|
send_one_by_one = False
|
||||||
for m in ret:
|
for seg in message.chain:
|
||||||
await self.bot.send(self.message_obj.raw_message, [m])
|
if isinstance(seg, (Node, Nodes)):
|
||||||
await asyncio.sleep(random.uniform(0.75, 2.5))
|
# 转发消息不能和普通消息混在一起发送
|
||||||
|
send_one_by_one = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if send_one_by_one:
|
||||||
|
for seg in message.chain:
|
||||||
|
if isinstance(seg, (Node, Nodes)):
|
||||||
|
# 合并转发消息
|
||||||
|
|
||||||
|
if isinstance(seg, Node):
|
||||||
|
nodes = Nodes([seg])
|
||||||
|
seg = nodes
|
||||||
|
|
||||||
|
payload = seg.toDict()
|
||||||
|
if self.get_group_id():
|
||||||
|
payload["group_id"] = self.get_group_id()
|
||||||
|
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||||
|
else:
|
||||||
|
payload["user_id"] = self.get_sender_id()
|
||||||
|
await self.bot.call_action(
|
||||||
|
"send_private_forward_msg", **payload
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.bot.send(
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
await AiocqhttpMessageEvent._parse_onebot_json(
|
||||||
|
MessageChain([seg])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
else:
|
else:
|
||||||
await self.bot.send(self.message_obj.raw_message, ret)
|
await self.bot.send(self.message_obj.raw_message, ret)
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator)
|
||||||
|
|
||||||
|
async def get_group(self, group_id=None, **kwargs):
|
||||||
|
if isinstance(group_id, str) and group_id.isdigit():
|
||||||
|
group_id = int(group_id)
|
||||||
|
elif self.get_group_id():
|
||||||
|
group_id = int(self.get_group_id())
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
info: dict = await self.bot.call_action(
|
||||||
|
"get_group_info",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
members: typing.List[typing.Dict] = await self.bot.call_action(
|
||||||
|
"get_group_member_list",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
owner_id = None
|
||||||
|
admin_ids = []
|
||||||
|
for member in members:
|
||||||
|
if member["role"] == "owner":
|
||||||
|
owner_id = member["user_id"]
|
||||||
|
if member["role"] == "admin":
|
||||||
|
admin_ids.append(member["user_id"])
|
||||||
|
|
||||||
|
group = Group(
|
||||||
|
group_id=str(group_id),
|
||||||
|
group_name=info.get("group_name"),
|
||||||
|
group_avatar="",
|
||||||
|
group_admins=admin_ids,
|
||||||
|
group_owner=str(owner_id),
|
||||||
|
members=[
|
||||||
|
MessageMember(
|
||||||
|
user_id=member["user_id"],
|
||||||
|
nickname=member.get("nickname") or member.get("card"),
|
||||||
|
)
|
||||||
|
for member in members
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return group
|
||||||
|
|||||||
@@ -2,9 +2,16 @@ import os
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any
|
||||||
from aiocqhttp import CQHttp, Event
|
from aiocqhttp import CQHttp, Event
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from .aiocqhttp_message_event import * # noqa: F403
|
from .aiocqhttp_message_event import * # noqa: F403
|
||||||
from astrbot.api.message_components import * # noqa: F403
|
from astrbot.api.message_components import * # noqa: F403
|
||||||
@@ -15,23 +22,63 @@ from ...register import register_platform_adapter
|
|||||||
from aiocqhttp.exceptions import ActionFailed
|
from aiocqhttp.exceptions import ActionFailed
|
||||||
from astrbot.core.utils.io import download_file
|
from astrbot.core.utils.io import download_file
|
||||||
|
|
||||||
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
|
|
||||||
|
@register_platform_adapter(
|
||||||
|
"aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。"
|
||||||
|
)
|
||||||
class AiocqhttpAdapter(Platform):
|
class AiocqhttpAdapter(Platform):
|
||||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
super().__init__(event_queue)
|
super().__init__(event_queue)
|
||||||
|
|
||||||
self.config = platform_config
|
self.config = platform_config
|
||||||
self.settings = platform_settings
|
self.settings = platform_settings
|
||||||
self.unique_session = platform_settings['unique_session']
|
self.unique_session = platform_settings["unique_session"]
|
||||||
self.host = platform_config['ws_reverse_host']
|
self.host = platform_config["ws_reverse_host"]
|
||||||
self.port = platform_config['ws_reverse_port']
|
self.port = platform_config["ws_reverse_port"]
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
"aiocqhttp",
|
name="aiocqhttp",
|
||||||
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
self.bot = CQHttp(
|
||||||
|
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
||||||
|
)
|
||||||
|
|
||||||
|
@self.bot.on_request()
|
||||||
|
async def request(event: Event):
|
||||||
|
abm = await self.convert_message(event)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
@self.bot.on_notice()
|
||||||
|
async def notice(event: Event):
|
||||||
|
abm = await self.convert_message(event)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
@self.bot.on_message("group")
|
||||||
|
async def group(event: Event):
|
||||||
|
abm = await self.convert_message(event)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
@self.bot.on_message("private")
|
||||||
|
async def private(event: Event):
|
||||||
|
abm = await self.convert_message(event)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
@self.bot.on_websocket_connection
|
||||||
|
def on_websocket_connection(_):
|
||||||
|
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
||||||
match session.message_type.value:
|
match session.message_type.value:
|
||||||
case MessageType.GROUP_MESSAGE.value:
|
case MessageType.GROUP_MESSAGE.value:
|
||||||
@@ -40,28 +87,106 @@ class AiocqhttpAdapter(Platform):
|
|||||||
_, group_id = session.session_id.split("_")
|
_, group_id = session.session_id.split("_")
|
||||||
await self.bot.send_group_msg(group_id=group_id, message=ret)
|
await self.bot.send_group_msg(group_id=group_id, message=ret)
|
||||||
else:
|
else:
|
||||||
await self.bot.send_group_msg(group_id=session.session_id, message=ret)
|
await self.bot.send_group_msg(
|
||||||
|
group_id=session.session_id, message=ret
|
||||||
|
)
|
||||||
case MessageType.FRIEND_MESSAGE.value:
|
case MessageType.FRIEND_MESSAGE.value:
|
||||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||||
await super().send_by_session(session, message_chain)
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||||
|
logger.debug(f"[aiocqhttp] RawMessage {event}")
|
||||||
|
|
||||||
|
if event["post_type"] == "message":
|
||||||
|
abm = await self._convert_handle_message_event(event)
|
||||||
|
elif event["post_type"] == "notice":
|
||||||
|
abm = await self._convert_handle_notice_event(event)
|
||||||
|
elif event["post_type"] == "request":
|
||||||
|
abm = await self._convert_handle_request_event(event)
|
||||||
|
|
||||||
|
return abm
|
||||||
|
|
||||||
|
async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
|
||||||
|
"""OneBot V11 请求类事件"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.tag = "aiocqhttp"
|
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||||
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname'])
|
if "group_id" in event and event["group_id"]:
|
||||||
|
|
||||||
if event['message_type'] == 'group':
|
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
elif event['message_type'] == 'private':
|
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
|
||||||
|
|
||||||
if self.unique_session:
|
|
||||||
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
|
|
||||||
else:
|
else:
|
||||||
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.message = []
|
||||||
|
abm.timestamp = int(time.time())
|
||||||
|
abm.message_id = uuid.uuid4().hex
|
||||||
|
abm.raw_message = event
|
||||||
|
return abm
|
||||||
|
|
||||||
|
async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
|
||||||
|
"""OneBot V11 通知类事件"""
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.self_id = str(event.self_id)
|
||||||
|
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||||
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
|
if "group_id" in event and event["group_id"]:
|
||||||
|
abm.group_id = str(event.group_id)
|
||||||
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
|
else:
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = (
|
||||||
|
str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||||
|
) # 也保留群组 id
|
||||||
|
else:
|
||||||
|
abm.session_id = (
|
||||||
|
str(event.group_id)
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE
|
||||||
|
else abm.sender.user_id
|
||||||
|
)
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.message = []
|
||||||
|
abm.raw_message = event
|
||||||
|
abm.timestamp = int(time.time())
|
||||||
|
abm.message_id = uuid.uuid4().hex
|
||||||
|
|
||||||
|
if "sub_type" in event:
|
||||||
|
if event["sub_type"] == "poke" and "target_id" in event:
|
||||||
|
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
||||||
|
|
||||||
|
return abm
|
||||||
|
|
||||||
|
async def _convert_handle_message_event(
|
||||||
|
self, event: Event, get_reply=True
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""OneBot V11 消息类事件
|
||||||
|
|
||||||
|
@param event: 事件对象
|
||||||
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
|
"""
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.self_id = str(event.self_id)
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
str(event.sender["user_id"]), event.sender["nickname"]
|
||||||
|
)
|
||||||
|
if event["message_type"] == "group":
|
||||||
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
|
abm.group_id = str(event.group_id)
|
||||||
|
elif event["message_type"] == "private":
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = (
|
||||||
|
abm.sender.user_id + "_" + str(event.group_id)
|
||||||
|
) # 也保留群组 id
|
||||||
|
else:
|
||||||
|
abm.session_id = (
|
||||||
|
str(event.group_id)
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE
|
||||||
|
else abm.sender.user_id
|
||||||
|
)
|
||||||
|
|
||||||
abm.message_id = str(event.message_id)
|
abm.message_id = str(event.message_id)
|
||||||
abm.message = []
|
abm.message = []
|
||||||
@@ -75,95 +200,131 @@ class AiocqhttpAdapter(Platform):
|
|||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"回复消息失败: {e}")
|
logger.error(f"回复消息失败: {e}")
|
||||||
return
|
return
|
||||||
logger.debug(f"aiocqhttp: 收到消息: {event.message}")
|
|
||||||
|
# 按消息段类型类型适配
|
||||||
for m in event.message:
|
for m in event.message:
|
||||||
t = m['type']
|
t = m["type"]
|
||||||
a = None
|
a = None
|
||||||
if t == 'text':
|
if t == "text":
|
||||||
message_str += m['data']['text'].strip()
|
message_str += m["data"]["text"].strip()
|
||||||
elif t == 'file':
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
if m['data']['url'] and m['data']['url'].startswith("http"):
|
abm.message.append(a)
|
||||||
|
|
||||||
|
elif t == "file":
|
||||||
|
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||||
# Lagrange
|
# Lagrange
|
||||||
logger.info("guessing lagrange")
|
logger.info("guessing lagrange")
|
||||||
|
|
||||||
file_name = m['data'].get('file_name', "file")
|
file_name = m["data"].get("file_name", "file")
|
||||||
path = os.path.join("data/temp", file_name)
|
path = os.path.join("data/temp", file_name)
|
||||||
await download_file(m['data']['url'], path)
|
await download_file(m["data"]["url"], path)
|
||||||
|
|
||||||
m['data'] = {
|
m["data"] = {"file": path, "name": file_name}
|
||||||
"file": path,
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
"name": file_name
|
abm.message.append(a)
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# Napcat, LLBot
|
# Napcat, LLBot
|
||||||
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
|
ret = await self.bot.call_action(
|
||||||
if not ret.get('file', None):
|
action="get_file",
|
||||||
|
file_id=event.message[0]["data"]["file_id"],
|
||||||
|
)
|
||||||
|
if not ret.get("file", None):
|
||||||
raise ValueError(f"无法解析文件响应: {ret}")
|
raise ValueError(f"无法解析文件响应: {ret}")
|
||||||
if not os.path.exists(ret['file']):
|
if not os.path.exists(ret["file"]):
|
||||||
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
|
raise FileNotFoundError(
|
||||||
|
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
|
||||||
|
)
|
||||||
|
|
||||||
m['data'] = {
|
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
|
||||||
"file": ret['file'],
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
"name": ret['file_name']
|
abm.message.append(a)
|
||||||
}
|
|
||||||
except ActionFailed as e:
|
except ActionFailed as e:
|
||||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||||
|
|
||||||
a = ComponentTypes[t](**m['data']) # noqa: F405
|
elif t == "reply":
|
||||||
|
if not get_reply:
|
||||||
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
abm.message.append(a)
|
abm.message.append(a)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
reply_event_data = await self.bot.call_action(
|
||||||
|
action="get_msg",
|
||||||
|
message_id=int(m["data"]["id"]),
|
||||||
|
)
|
||||||
|
abm_reply = await self._convert_handle_message_event(
|
||||||
|
Event.from_payload(reply_event_data), get_reply=False
|
||||||
|
)
|
||||||
|
|
||||||
|
reply_seg = Reply(
|
||||||
|
id=abm_reply.message_id,
|
||||||
|
chain=abm_reply.message,
|
||||||
|
sender_id=abm_reply.sender.user_id,
|
||||||
|
sender_nickname=abm_reply.sender.nickname,
|
||||||
|
time=abm_reply.timestamp,
|
||||||
|
message_str=abm_reply.message_str,
|
||||||
|
text=abm_reply.message_str, # for compatibility
|
||||||
|
qq=abm_reply.sender.user_id, # for compatibility
|
||||||
|
)
|
||||||
|
|
||||||
|
abm.message.append(reply_seg)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"获取引用消息失败: {e}。")
|
||||||
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
|
abm.message.append(a)
|
||||||
|
else:
|
||||||
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
|
abm.message.append(a)
|
||||||
|
|
||||||
abm.timestamp = int(time.time())
|
abm.timestamp = int(time.time())
|
||||||
abm.message_str = message_str
|
abm.message_str = message_str
|
||||||
abm.raw_message = event
|
abm.raw_message = event
|
||||||
|
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
def run(self) -> Awaitable[Any]:
|
def run(self) -> Awaitable[Any]:
|
||||||
if not self.host or not self.port:
|
if not self.host or not self.port:
|
||||||
return
|
logger.warning(
|
||||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
"aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199"
|
||||||
@self.bot.on_message('group')
|
)
|
||||||
async def group(event: Event):
|
self.host = "0.0.0.0"
|
||||||
abm = await self.convert_message(event)
|
self.port = 6199
|
||||||
if abm:
|
|
||||||
await self.handle_msg(abm)
|
|
||||||
|
|
||||||
@self.bot.on_message('private')
|
coro = self.bot.run_task(
|
||||||
async def private(event: Event):
|
host=self.host,
|
||||||
abm = await self.convert_message(event)
|
port=int(self.port),
|
||||||
if abm:
|
shutdown_trigger=self.shutdown_trigger_placeholder,
|
||||||
await self.handle_msg(abm)
|
)
|
||||||
|
|
||||||
@self.bot.on_websocket_connection
|
|
||||||
def on_websocket_connection(_):
|
|
||||||
logger.info("aiocqhttp 适配器已连接。")
|
|
||||||
|
|
||||||
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
|
|
||||||
|
|
||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
logging.getLogger('aiocqhttp').setLevel(logging.ERROR)
|
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
return coro
|
||||||
|
|
||||||
return bot
|
async def terminate(self):
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
async def shutdown_trigger_placeholder(self):
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
logger.info("aiocqhttp 适配器已被优雅地关闭")
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return self.metadata
|
return self.metadata
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
|
||||||
while not self._event_queue.closed:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("aiocqhttp 适配器已关闭。")
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
|
||||||
message_event = AiocqhttpMessageEvent(
|
message_event = AiocqhttpMessageEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
message_obj=message,
|
message_obj=message,
|
||||||
platform_meta=self.meta(),
|
platform_meta=self.meta(),
|
||||||
session_id=message.session_id,
|
session_id=message.session_id,
|
||||||
bot=self.bot
|
bot=self.bot,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
def get_client(self) -> CQHttp:
|
||||||
|
return self.bot
|
||||||
|
|||||||
228
astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Normal file
228
astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import aiohttp
|
||||||
|
import dingtalk_stream
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Image, Plain, At
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from .dingtalk_event import DingtalkMessageEvent
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from astrbot import logger
|
||||||
|
from dingtalk_stream import AckMessage
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
|
|
||||||
|
|
||||||
|
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||||
|
async def process(self, event: dingtalk_stream.EventMessage):
|
||||||
|
print(
|
||||||
|
"2",
|
||||||
|
event.headers.event_type,
|
||||||
|
event.headers.event_id,
|
||||||
|
event.headers.event_born_time,
|
||||||
|
event.data,
|
||||||
|
)
|
||||||
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
|
||||||
|
class DingtalkPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
|
||||||
|
self.config = platform_config
|
||||||
|
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
self.client_id = platform_config["client_id"]
|
||||||
|
self.client_secret = platform_config["client_secret"]
|
||||||
|
|
||||||
|
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||||
|
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||||
|
logger.debug(f"dingtalk: {message.data}")
|
||||||
|
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||||
|
abm = await self.convert_msg(im)
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
|
self.client = AstrCallbackClient()
|
||||||
|
|
||||||
|
credential = dingtalk_stream.Credential(self.client_id, self.client_secret)
|
||||||
|
client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger)
|
||||||
|
client.register_all_event_handler(MyEventHandler())
|
||||||
|
client.register_callback_handler(
|
||||||
|
dingtalk_stream.ChatbotMessage.TOPIC, self.client
|
||||||
|
)
|
||||||
|
self.client_ = client # 用于 websockets 的 client
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="dingtalk",
|
||||||
|
description="钉钉机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def convert_msg(
|
||||||
|
self, message: dingtalk_stream.ChatbotMessage
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.message = []
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.timestamp = int(message.create_at / 1000)
|
||||||
|
abm.type = (
|
||||||
|
MessageType.GROUP_MESSAGE
|
||||||
|
if message.conversation_type == "2"
|
||||||
|
else MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
user_id=message.sender_id, nickname=message.sender_nick
|
||||||
|
)
|
||||||
|
abm.self_id = message.chatbot_user_id
|
||||||
|
abm.message_id = message.message_id
|
||||||
|
abm.raw_message = message
|
||||||
|
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
if message.is_in_at_list:
|
||||||
|
abm.message.append(At(qq=abm.self_id))
|
||||||
|
abm.group_id = message.conversation_id
|
||||||
|
if self.unique_session:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.group_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
|
message_type: str = message.message_type
|
||||||
|
match message_type:
|
||||||
|
case "text":
|
||||||
|
abm.message_str = message.text.content.strip()
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
case "richText":
|
||||||
|
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||||
|
contents: list[dict] = rtc.rich_text_list
|
||||||
|
for content in contents:
|
||||||
|
plains = ""
|
||||||
|
if "text" in content:
|
||||||
|
plains += content["text"]
|
||||||
|
abm.message.append(Plain(plains))
|
||||||
|
elif "type" in content and content["type"] == "picture":
|
||||||
|
f_path = await self.download_ding_file(
|
||||||
|
content["downloadCode"],
|
||||||
|
message.robot_code,
|
||||||
|
"jpg",
|
||||||
|
)
|
||||||
|
abm.message.append(Image.fromFileSystem(f_path))
|
||||||
|
case "audio":
|
||||||
|
pass
|
||||||
|
|
||||||
|
return abm # 别忘了返回转换后的消息对象
|
||||||
|
|
||||||
|
async def download_ding_file(
|
||||||
|
self, download_code: str, robot_code: str, ext: str
|
||||||
|
) -> str:
|
||||||
|
"""下载钉钉文件
|
||||||
|
|
||||||
|
:param access_token: 钉钉机器人的 access_token
|
||||||
|
:param download_code: 下载码
|
||||||
|
:param robot_code: 机器人码
|
||||||
|
:param ext: 文件后缀
|
||||||
|
:return: 文件路径
|
||||||
|
"""
|
||||||
|
access_token = await self.get_access_token()
|
||||||
|
headers = {
|
||||||
|
"x-acs-dingtalk-access-token": access_token,
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"downloadCode": download_code,
|
||||||
|
"robotCode": robot_code,
|
||||||
|
}
|
||||||
|
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"下载钉钉文件失败: {resp.status}, {await resp.text()}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
resp_data = await resp.json()
|
||||||
|
download_url = resp_data["data"]["downloadUrl"]
|
||||||
|
await download_file(download_url, f_path)
|
||||||
|
return f_path
|
||||||
|
|
||||||
|
async def get_access_token(self) -> str:
|
||||||
|
payload = {
|
||||||
|
"appKey": self.client_id,
|
||||||
|
"appSecret": self.client_secret,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return (await resp.json())["data"]["accessToken"]
|
||||||
|
|
||||||
|
async def handle_msg(self, abm: AstrBotMessage):
|
||||||
|
event = DingtalkMessageEvent(
|
||||||
|
message_str=abm.message_str,
|
||||||
|
message_obj=abm,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=abm.session_id,
|
||||||
|
client=self.client,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_queue.put_nowait(event)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
# await self.client_.start()
|
||||||
|
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
|
||||||
|
def start_client(loop: asyncio.AbstractEventLoop):
|
||||||
|
try:
|
||||||
|
self._shutdown_event = threading.Event()
|
||||||
|
task = loop.create_task(self.client_.start())
|
||||||
|
self._shutdown_event.wait()
|
||||||
|
if task.done():
|
||||||
|
task.result()
|
||||||
|
except Exception as e:
|
||||||
|
if "Graceful shutdown" in str(e):
|
||||||
|
logger.info("钉钉适配器已被优雅地关闭")
|
||||||
|
return
|
||||||
|
logger.error(f"钉钉机器人启动失败: {e}")
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(None, start_client, loop)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
def monkey_patch_close():
|
||||||
|
raise Exception("Graceful shutdown")
|
||||||
|
|
||||||
|
self.client_.open_connection = monkey_patch_close
|
||||||
|
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
return self.client
|
||||||
75
astrbot/core/platform/sources/dingtalk/dingtalk_event.py
Normal file
75
astrbot/core/platform/sources/dingtalk/dingtalk_event.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import asyncio
|
||||||
|
import dingtalk_stream
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
|
class DingtalkMessageEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str,
|
||||||
|
message_obj,
|
||||||
|
platform_meta,
|
||||||
|
session_id,
|
||||||
|
client: dingtalk_stream.ChatbotHandler,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def send_with_client(
|
||||||
|
self, client: dingtalk_stream.ChatbotHandler, message: MessageChain
|
||||||
|
):
|
||||||
|
for segment in message.chain:
|
||||||
|
if isinstance(segment, Comp.Plain):
|
||||||
|
segment.text = segment.text.strip()
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
client.reply_markdown,
|
||||||
|
"AstrBot",
|
||||||
|
segment.text,
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
)
|
||||||
|
elif isinstance(segment, Comp.Image):
|
||||||
|
markdown_str = ""
|
||||||
|
if segment.file and segment.file.startswith("file:///"):
|
||||||
|
logger.warning(
|
||||||
|
"dingtalk only support url image, not: " + segment.file
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif segment.file and segment.file.startswith("http"):
|
||||||
|
markdown_str += f"\n\n"
|
||||||
|
elif segment.file and segment.file.startswith("base64://"):
|
||||||
|
logger.warning("dingtalk only support url image, not base64")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"dingtalk only support url image, not: " + segment.file
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
ret = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
client.reply_markdown,
|
||||||
|
"😄",
|
||||||
|
markdown_str,
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
)
|
||||||
|
logger.debug(f"send image: {ret}")
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
await self.send_with_client(self.client, message)
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator)
|
||||||
@@ -1,29 +1,49 @@
|
|||||||
import threading
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
|
||||||
import quart
|
|
||||||
import base64
|
import base64
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import anyio
|
||||||
|
import quart
|
||||||
|
|
||||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
|
||||||
from astrbot.api.message_components import Plain, Image, At, Record
|
|
||||||
from astrbot.api import logger, sp
|
from astrbot.api import logger, sp
|
||||||
from .downloader import GeweDownloader
|
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||||
|
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from .downloader import GeweDownloader
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .xml_data_parser import GeweDataParser
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
logger.warning(
|
||||||
|
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SimpleGewechatClient():
|
class SimpleGewechatClient:
|
||||||
'''针对 Gewechat 的简单实现。
|
"""针对 Gewechat 的简单实现。
|
||||||
|
|
||||||
@author: Soulter
|
@author: Soulter
|
||||||
@website: https://github.com/Soulter
|
@website: https://github.com/Soulter
|
||||||
'''
|
"""
|
||||||
def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue):
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
nickname: str,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
event_queue: asyncio.Queue,
|
||||||
|
):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
if self.base_url.endswith('/'):
|
if self.base_url.endswith("/"):
|
||||||
self.base_url = self.base_url[:-1]
|
self.base_url = self.base_url[:-1]
|
||||||
|
|
||||||
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
|
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
|
||||||
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
|
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
|
||||||
|
|
||||||
self.base_url += "/v2/api"
|
self.base_url += "/v2/api"
|
||||||
|
|
||||||
@@ -39,8 +59,14 @@ class SimpleGewechatClient():
|
|||||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||||
|
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
self.server.add_url_rule(
|
||||||
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
|
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||||
|
)
|
||||||
|
self.server.add_url_rule(
|
||||||
|
"/astrbot-gewechat/file/<file_id>",
|
||||||
|
view_func=self._handle_file,
|
||||||
|
methods=["GET"],
|
||||||
|
)
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
@@ -51,65 +77,138 @@ class SimpleGewechatClient():
|
|||||||
|
|
||||||
self.multimedia_downloader = None
|
self.multimedia_downloader = None
|
||||||
|
|
||||||
|
self.userrealnames = {}
|
||||||
|
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def get_token_id(self):
|
async def get_token_id(self):
|
||||||
|
"""获取 Gewechat Token。"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
self.token = json_blob['data']
|
self.token = json_blob["data"]
|
||||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||||
self.headers = {
|
self.headers = {"X-GEWE-TOKEN": self.token}
|
||||||
"X-GEWE-TOKEN": self.token
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||||
type_name = data['TypeName']
|
if "TypeName" in data:
|
||||||
|
type_name = data["TypeName"]
|
||||||
|
elif "type_name" in data:
|
||||||
|
type_name = data["type_name"]
|
||||||
|
else:
|
||||||
|
raise Exception("无法识别的消息类型")
|
||||||
|
|
||||||
|
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||||
|
if type_name == "ModContacts":
|
||||||
|
logger.info("gewechat下发:ModContacts消息通知。")
|
||||||
|
return
|
||||||
|
if type_name == "DelContacts":
|
||||||
|
logger.info("gewechat下发:DelContacts消息通知。")
|
||||||
|
return
|
||||||
|
|
||||||
if type_name == "Offline":
|
if type_name == "Offline":
|
||||||
logger.critical("收到 gewechat 下线通知。")
|
logger.critical("收到 gewechat 下线通知。")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
d = None
|
||||||
|
if "Data" in data:
|
||||||
|
d = data["Data"]
|
||||||
|
elif "data" in data:
|
||||||
|
d = data["data"]
|
||||||
|
|
||||||
|
if not d:
|
||||||
|
logger.warning(f"消息不含 data 字段: {data}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if "CreateTime" in d:
|
||||||
|
# 得到系统 UTF+8 的 ts
|
||||||
|
tz_offset = datetime.timedelta(hours=8)
|
||||||
|
tz = datetime.timezone(tz_offset)
|
||||||
|
ts = datetime.datetime.now(tz).timestamp()
|
||||||
|
create_time = d["CreateTime"]
|
||||||
|
if create_time < ts - 30:
|
||||||
|
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||||
|
return
|
||||||
|
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
d = data['Data']
|
|
||||||
|
|
||||||
from_user_name = d['FromUserName']['string'] # 消息来源
|
from_user_name = d["FromUserName"]["string"] # 消息来源
|
||||||
d['to_wxid'] = from_user_name # 用于发信息
|
d["to_wxid"] = from_user_name # 用于发信息
|
||||||
|
|
||||||
abm.message_id = str(d.get('MsgId'))
|
abm.message_id = str(d.get("MsgId"))
|
||||||
abm.session_id = from_user_name
|
abm.session_id = from_user_name
|
||||||
abm.self_id = data['Wxid'] # 机器人的 wxid
|
abm.self_id = data["Wxid"] # 机器人的 wxid
|
||||||
|
|
||||||
user_id = "" # 发送人 wxid
|
user_id = "" # 发送人 wxid
|
||||||
content = d['Content']['string'] # 消息内容
|
content = d["Content"]["string"] # 消息内容
|
||||||
|
|
||||||
at_me = False
|
at_me = False
|
||||||
if "@chatroom" in from_user_name:
|
if "@chatroom" in from_user_name:
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
_t = content.split(':\n')
|
_t = content.split(":\n")
|
||||||
user_id = _t[0]
|
user_id = _t[0]
|
||||||
content = _t[1]
|
content = _t[1]
|
||||||
if '\u2005' in content:
|
if "\u2005" in content:
|
||||||
# at
|
# at
|
||||||
content = content.split('\u2005')[1]
|
# content = content.split('\u2005')[1]
|
||||||
|
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||||
abm.group_id = from_user_name
|
abm.group_id = from_user_name
|
||||||
# at
|
# at
|
||||||
msg_source = d['MsgSource']
|
msg_source = d["MsgSource"]
|
||||||
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
|
if (
|
||||||
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
|
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||||
|
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||||
|
):
|
||||||
|
at_me = True
|
||||||
|
if "在群聊中@了你" in d.get("PushContent", ""):
|
||||||
at_me = True
|
at_me = True
|
||||||
else:
|
else:
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
user_id = from_user_name
|
user_id = from_user_name
|
||||||
|
|
||||||
|
# 检查消息是否由自己发送,若是则忽略
|
||||||
|
if user_id == abm.self_id:
|
||||||
|
logger.info("忽略自己发送的消息")
|
||||||
|
return None
|
||||||
|
|
||||||
abm.message = []
|
abm.message = []
|
||||||
if at_me:
|
if at_me:
|
||||||
abm.message.insert(0, At(qq=abm.self_id))
|
abm.message.insert(0, At(qq=abm.self_id))
|
||||||
|
|
||||||
user_real_name = d['PushContent'].split(' : ')[0] \
|
# 解析用户真实名字
|
||||||
.replace('在群聊中@了你', '') \
|
user_real_name = "unknown"
|
||||||
.replace('在群聊中发了一段语音', '') # 真实昵称
|
if abm.group_id:
|
||||||
|
if (
|
||||||
|
abm.group_id not in self.userrealnames
|
||||||
|
or user_id not in self.userrealnames[abm.group_id]
|
||||||
|
):
|
||||||
|
# 获取群成员列表,并且缓存
|
||||||
|
if abm.group_id not in self.userrealnames:
|
||||||
|
self.userrealnames[abm.group_id] = {}
|
||||||
|
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||||
|
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||||
|
if member_list and "memberList" in member_list:
|
||||||
|
for member in member_list["memberList"]:
|
||||||
|
self.userrealnames[abm.group_id][member["wxid"]] = member[
|
||||||
|
"nickName"
|
||||||
|
]
|
||||||
|
if user_id in self.userrealnames[abm.group_id]:
|
||||||
|
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||||
|
else:
|
||||||
|
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||||
|
else:
|
||||||
|
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0]
|
||||||
|
|
||||||
abm.sender = MessageMember(user_id, user_real_name)
|
abm.sender = MessageMember(user_id, user_real_name)
|
||||||
abm.raw_message = d
|
abm.raw_message = d
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
|
|
||||||
|
if user_id == "weixin":
|
||||||
|
# 忽略微信团队消息
|
||||||
|
return
|
||||||
|
|
||||||
# 不同消息类型
|
# 不同消息类型
|
||||||
match d['MsgType']:
|
match d["MsgType"]:
|
||||||
case 1:
|
case 1:
|
||||||
# 文本消息
|
# 文本消息
|
||||||
abm.message.append(Plain(content))
|
abm.message.append(Plain(content))
|
||||||
@@ -117,8 +216,7 @@ class SimpleGewechatClient():
|
|||||||
case 3:
|
case 3:
|
||||||
# 图片消息
|
# 图片消息
|
||||||
file_url = await self.multimedia_downloader.download_image(
|
file_url = await self.multimedia_downloader.download_image(
|
||||||
self.appid,
|
self.appid, content
|
||||||
content
|
|
||||||
)
|
)
|
||||||
logger.debug(f"下载图片: {file_url}")
|
logger.debug(f"下载图片: {file_url}")
|
||||||
file_path = await download_image_by_url(file_url)
|
file_path = await download_image_by_url(file_url)
|
||||||
@@ -126,34 +224,63 @@ class SimpleGewechatClient():
|
|||||||
|
|
||||||
case 34:
|
case 34:
|
||||||
# 语音消息
|
# 语音消息
|
||||||
# data = await self.multimedia_downloader.download_voice(
|
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||||
# self.appid,
|
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||||
# content,
|
|
||||||
# abm.message_id
|
|
||||||
# )
|
|
||||||
# print(data)
|
|
||||||
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
|
|
||||||
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
|
|
||||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(voice_data)
|
async with await anyio.open_file(file_path, "wb") as f:
|
||||||
|
await f.write(voice_data)
|
||||||
abm.message.append(Record(file=file_path, url=file_path))
|
abm.message.append(Record(file=file_path, url=file_path))
|
||||||
|
|
||||||
case _:
|
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||||
logger.error(f"未实现的消息类型: {d['MsgType']}")
|
case 37: # 好友申请
|
||||||
return
|
logger.info("消息类型(37):好友申请")
|
||||||
|
case 42: # 名片
|
||||||
|
logger.info("消息类型(42):名片")
|
||||||
|
case 43: # 视频
|
||||||
|
video = Video(file="", cover=content)
|
||||||
|
abm.message.append(video)
|
||||||
|
case 47: # emoji
|
||||||
|
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||||
|
emoji = data_parser.parse_emoji()
|
||||||
|
abm.message.append(emoji)
|
||||||
|
case 48: # 地理位置
|
||||||
|
logger.info("消息类型(48):地理位置")
|
||||||
|
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||||
|
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||||
|
abm_data = data_parser.parse_mutil_49()
|
||||||
|
if abm_data:
|
||||||
|
abm.message.append(abm_data)
|
||||||
|
case 51: # 帐号消息同步?
|
||||||
|
logger.info("消息类型(51):帐号消息同步?")
|
||||||
|
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||||
|
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||||
|
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||||
|
logger.info(
|
||||||
|
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"abm: {abm}")
|
case _:
|
||||||
|
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||||
|
abm.raw_message = d
|
||||||
|
|
||||||
|
logger.debug(f"abm: {abm}")
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
async def callback(self):
|
async def _callback(self):
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
logger.debug(f"收到 gewechat 回调: {data}")
|
logger.debug(f"收到 gewechat 回调: {data}")
|
||||||
|
|
||||||
if data.get('testMsg', None):
|
if data.get("testMsg", None):
|
||||||
return quart.jsonify({"r": "AstrBot ACK"})
|
return quart.jsonify({"r": "AstrBot ACK"})
|
||||||
|
|
||||||
|
abm = None
|
||||||
|
try:
|
||||||
abm = await self._convert(data)
|
abm = await self._convert(data)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.warning(
|
||||||
|
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。"
|
||||||
|
)
|
||||||
|
|
||||||
if abm:
|
if abm:
|
||||||
coro = getattr(self, "on_event_received")
|
coro = getattr(self, "on_event_received")
|
||||||
@@ -162,7 +289,7 @@ class SimpleGewechatClient():
|
|||||||
|
|
||||||
return quart.jsonify({"r": "AstrBot ACK"})
|
return quart.jsonify({"r": "AstrBot ACK"})
|
||||||
|
|
||||||
async def handle_file(self, file_id):
|
async def _handle_file(self, file_id):
|
||||||
file_path = f"data/temp/{file_id}"
|
file_path = f"data/temp/{file_id}"
|
||||||
return await quart.send_file(file_path)
|
return await quart.send_file(file_path)
|
||||||
|
|
||||||
@@ -173,44 +300,40 @@ class SimpleGewechatClient():
|
|||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/tools/setCallback",
|
f"{self.base_url}/tools/setCallback",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
json={
|
json={"token": self.token, "callbackUrl": self.callback_url},
|
||||||
"token": self.token,
|
|
||||||
"callbackUrl": self.callback_url
|
|
||||||
}
|
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.info(f"设置回调结果: {json_blob}")
|
logger.info(f"设置回调结果: {json_blob}")
|
||||||
if json_blob['ret'] != 200:
|
if json_blob["ret"] != 200:
|
||||||
raise Exception(f"设置回调失败: {json_blob}")
|
raise Exception(f"设置回调失败: {json_blob}")
|
||||||
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
|
logger.info(
|
||||||
|
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
|
||||||
|
)
|
||||||
|
|
||||||
async def start_polling(self):
|
async def start_polling(self):
|
||||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host=self.host,
|
host="0.0.0.0",
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
while not self.event_queue.closed:
|
await self.shutdown_event.wait()
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("gewechat 适配器已关闭。")
|
|
||||||
|
|
||||||
async def check_online(self, appid: str):
|
async def check_online(self, appid: str):
|
||||||
# /login/checkOnline
|
"""检查 APPID 对应的设备是否在线。"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/login/checkOnline",
|
f"{self.base_url}/login/checkOnline",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
json={
|
json={"appId": appid},
|
||||||
"appId": appid
|
|
||||||
}
|
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
return json_blob['data']
|
return json_blob["data"]
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
|
"""登出 gewechat。"""
|
||||||
if self.appid:
|
if self.appid:
|
||||||
online = await self.check_online(self.appid)
|
online = await self.check_online(self.appid)
|
||||||
if online:
|
if online:
|
||||||
@@ -218,65 +341,108 @@ class SimpleGewechatClient():
|
|||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/login/logout",
|
f"{self.base_url}/login/logout",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
json={
|
json={"appId": self.appid},
|
||||||
"appId": self.appid
|
|
||||||
}
|
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.info(f"登出结果: {json_blob}")
|
logger.info(f"登出结果: {json_blob}")
|
||||||
|
|
||||||
async def login(self):
|
async def login(self):
|
||||||
|
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||||
if self.token is None:
|
if self.token is None:
|
||||||
await self.get_token_id()
|
await self.get_token_id()
|
||||||
|
|
||||||
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
|
self.multimedia_downloader = GeweDownloader(
|
||||||
|
self.base_url, self.download_base_url, self.token
|
||||||
|
)
|
||||||
|
|
||||||
if self.appid:
|
if self.appid:
|
||||||
|
try:
|
||||||
online = await self.check_online(self.appid)
|
online = await self.check_online(self.appid)
|
||||||
if online:
|
if online:
|
||||||
logger.info(f"APPID: {self.appid} 已在线")
|
logger.info(f"APPID: {self.appid} 已在线")
|
||||||
return
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查在线状态失败: {e}")
|
||||||
|
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||||
|
self.appid = None
|
||||||
|
|
||||||
payload = {
|
payload = {"appId": self.appid}
|
||||||
"appId": self.appid
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.appid:
|
if self.appid:
|
||||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||||
|
|
||||||
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/login/getLoginQrCode",
|
f"{self.base_url}/login/getLoginQrCode",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
json=payload
|
json=payload,
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
if json_blob['ret'] != 200:
|
if json_blob["ret"] != 200:
|
||||||
|
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||||
|
if "设备不存在" in error_msg:
|
||||||
|
logger.error(
|
||||||
|
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||||
|
)
|
||||||
|
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||||
|
self.appid = None
|
||||||
|
return await self.login()
|
||||||
|
else:
|
||||||
raise Exception(f"获取二维码失败: {json_blob}")
|
raise Exception(f"获取二维码失败: {json_blob}")
|
||||||
qr_data = json_blob['data']['qrData']
|
qr_data = json_blob["data"]["qrData"]
|
||||||
qr_uuid = json_blob['data']['uuid']
|
qr_uuid = json_blob["data"]["uuid"]
|
||||||
appid = json_blob['data']['appId']
|
appid = json_blob["data"]["appId"]
|
||||||
logger.info(f"APPID: {appid}")
|
logger.info(f"APPID: {appid}")
|
||||||
logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}")
|
logger.warning(
|
||||||
|
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
# 执行登录
|
# 执行登录
|
||||||
retry_cnt = 64
|
retry_cnt = 64
|
||||||
payload.update({
|
payload.update({"uuid": qr_uuid, "appId": appid})
|
||||||
"uuid": qr_uuid,
|
|
||||||
"appId": appid
|
|
||||||
})
|
|
||||||
while retry_cnt > 0:
|
while retry_cnt > 0:
|
||||||
retry_cnt -= 1
|
retry_cnt -= 1
|
||||||
|
|
||||||
|
# 需要验证码
|
||||||
|
if os.path.exists("data/temp/gewe_code"):
|
||||||
|
with open("data/temp/gewe_code", "r") as f:
|
||||||
|
code = f.read().strip()
|
||||||
|
if not code:
|
||||||
|
logger.warning(
|
||||||
|
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
continue
|
||||||
|
payload["captchCode"] = code
|
||||||
|
logger.info(f"使用验证码: {code}")
|
||||||
|
try:
|
||||||
|
os.remove("data/temp/gewe_code")
|
||||||
|
except Exception:
|
||||||
|
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/login/checkLogin",
|
f"{self.base_url}/login/checkLogin",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
json=payload
|
json=payload,
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.info(f"检查登录状态: {json_blob}")
|
logger.info(f"检查登录状态: {json_blob}")
|
||||||
status = json_blob['data']['status']
|
|
||||||
nickname = json_blob['data'].get('nickName', '')
|
ret = json_blob["ret"]
|
||||||
|
msg = ""
|
||||||
|
if json_blob["data"] and "msg" in json_blob["data"]:
|
||||||
|
msg = json_blob["data"]["msg"]
|
||||||
|
if ret == 500 and "安全验证码" in msg:
|
||||||
|
logger.warning(
|
||||||
|
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
status = json_blob["data"]["status"]
|
||||||
|
nickname = json_blob["data"].get("nickName", "")
|
||||||
if status == 1:
|
if status == 1:
|
||||||
logger.info(f"等待确认...{nickname}")
|
logger.info(f"等待确认...{nickname}")
|
||||||
elif status == 2:
|
elif status == 2:
|
||||||
@@ -289,27 +455,52 @@ class SimpleGewechatClient():
|
|||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
if appid:
|
if appid:
|
||||||
sp.put(f"gewechat-appid-{nickname}", appid)
|
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||||
self.appid = appid
|
self.appid = appid
|
||||||
logger.info(f"已保存 APPID: {appid}")
|
logger.info(f"已保存 APPID: {appid}")
|
||||||
|
|
||||||
async def post_text(self, to_wxid, content: str):
|
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||||
|
"""获取群成员列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||||
|
"""
|
||||||
|
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomMemberList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
return json_blob["data"]
|
||||||
|
|
||||||
|
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||||
|
"""发送纯文本消息"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
if ats:
|
||||||
|
payload["ats"] = ats
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/message/postText",
|
f"{self.base_url}/message/postText", headers=self.headers, json=payload
|
||||||
headers=self.headers,
|
|
||||||
json=payload
|
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送消息结果: {json_blob}")
|
logger.debug(f"发送消息结果: {json_blob}")
|
||||||
|
|
||||||
async def post_image(self, to_wxid, image_url: str):
|
async def post_image(self, to_wxid, image_url: str):
|
||||||
|
"""发送图片消息"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -318,28 +509,246 @@ class SimpleGewechatClient():
|
|||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/message/postImage",
|
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
|
||||||
headers=self.headers,
|
|
||||||
json=payload
|
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送图片结果: {json_blob}")
|
logger.debug(f"发送图片结果: {json_blob}")
|
||||||
|
|
||||||
|
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||||
|
"""发送emoji消息"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"emojiMd5": emoji_md5,
|
||||||
|
"emojiSize": emoji_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||||
|
try:
|
||||||
|
if emoji_md5 != "" and emoji_size != "":
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postEmoji",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.info(
|
||||||
|
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.post_image(to_wxid, cdnurl)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
async def post_video(
|
||||||
|
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"videoUrl": video_url,
|
||||||
|
"thumbUrl": thumb_url,
|
||||||
|
"videoDuration": video_duration,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"发送视频结果: {json_blob}")
|
||||||
|
|
||||||
|
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||||
|
"""转发视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_wxid (str): 发送给谁
|
||||||
|
cnd_xml (str): 视频消息的cdn信息
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"xml": cnd_xml,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/forwardVideo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"转发视频结果: {json_blob}")
|
||||||
|
|
||||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||||
|
"""发送语音信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice_url (str): 语音文件的网络链接
|
||||||
|
voice_duration (int): 语音时长,毫秒
|
||||||
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
"voiceUrl": voice_url,
|
"voiceUrl": voice_url,
|
||||||
"voiceDuration": voice_duration
|
"voiceDuration": voice_duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"发送语音: {payload}")
|
logger.debug(f"发送语音: {payload}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/message/postVoice",
|
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||||
headers=self.headers,
|
|
||||||
json=payload
|
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送语音结果: {json_blob}")
|
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||||
|
|
||||||
|
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||||
|
"""发送文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_wxid (string): 微信ID
|
||||||
|
file_url (str): 文件的网络链接
|
||||||
|
file_name (str): 文件名
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"fileUrl": file_url,
|
||||||
|
"fileName": file_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"发送文件结果: {json_blob}")
|
||||||
|
|
||||||
|
async def add_friend(self, v3: str, v4: str, content: str):
|
||||||
|
"""申请添加好友"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"scene": 3,
|
||||||
|
"content": content,
|
||||||
|
"v4": v4,
|
||||||
|
"v3": v3,
|
||||||
|
"option": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/addContacts",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomInfo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_group_member(self, group_id: str):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomMemberList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def accept_group_invite(self, url: str):
|
||||||
|
"""同意进群"""
|
||||||
|
payload = {"appId": self.appid, "url": url}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/agreeJoinRoom",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def add_group_member_to_friend(
|
||||||
|
self, group_id: str, to_wxid: str, content: str
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
"content": content,
|
||||||
|
"memberWxid": to_wxid,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_user_or_group_info(self, *ids):
|
||||||
|
"""
|
||||||
|
获取用户或群组信息。
|
||||||
|
|
||||||
|
:param ids: 可变数量的 wxid 参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
wxids_str = list(ids)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/getDetailInfo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_contacts_list(self):
|
||||||
|
"""
|
||||||
|
获取通讯录列表
|
||||||
|
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||||
|
"""
|
||||||
|
payload = {"appId": self.appid}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/fetchContactsList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|||||||
@@ -2,50 +2,54 @@ from astrbot import logger
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class GeweDownloader():
|
|
||||||
|
class GeweDownloader:
|
||||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.download_base_url = download_base_url
|
self.download_base_url = download_base_url
|
||||||
self.headers = {
|
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
|
||||||
"Content-Type": "application/json",
|
|
||||||
"X-GEWE-TOKEN": token
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{baseurl}{route}",
|
f"{baseurl}{route}", headers=self.headers, json=payload
|
||||||
headers=self.headers,
|
|
||||||
json=payload
|
|
||||||
) as resp:
|
) as resp:
|
||||||
return await resp.read()
|
return await resp.read()
|
||||||
|
|
||||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||||
payload = {
|
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
|
||||||
"appId": appid,
|
|
||||||
"xml": xml,
|
|
||||||
"msgId": msg_id
|
|
||||||
}
|
|
||||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||||
|
|
||||||
async def download_image(self, appid: str, xml: str) -> str:
|
async def download_image(self, appid: str, xml: str) -> str:
|
||||||
'''返回一个可下载的 URL'''
|
"""返回一个可下载的 URL"""
|
||||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||||
|
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
try:
|
try:
|
||||||
payload = {
|
payload = {"appId": appid, "xml": xml, "type": choice}
|
||||||
"appId": appid,
|
data = await self._post_json(
|
||||||
"xml": xml,
|
self.base_url, "/message/downloadImage", payload
|
||||||
"type": choice
|
)
|
||||||
}
|
|
||||||
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
|
|
||||||
json_blob = json.loads(data)
|
json_blob = json.loads(data)
|
||||||
if 'fileUrl' in json_blob['data']:
|
if "fileUrl" in json_blob["data"]:
|
||||||
return self.download_base_url + json_blob['data']['fileUrl']
|
return self.download_base_url + json_blob["data"]["fileUrl"]
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"gewe download image: {e}")
|
logger.error(f"gewe download image: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raise Exception("无法下载图片")
|
raise Exception("无法下载图片")
|
||||||
|
|
||||||
|
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||||
|
"""下载emoji"""
|
||||||
|
try:
|
||||||
|
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||||
|
|
||||||
|
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||||
|
data = await self._post_json(
|
||||||
|
self.base_url, "/message/downloadEmojiMd5", payload
|
||||||
|
)
|
||||||
|
json_blob = json.loads(data)
|
||||||
|
return json_blob
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"gewe download emoji: {e}")
|
||||||
|
|||||||
@@ -1,24 +1,38 @@
|
|||||||
import wave
|
import wave
|
||||||
import uuid
|
import uuid
|
||||||
|
import traceback
|
||||||
import os
|
import os
|
||||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
|
||||||
|
from astrbot.core.utils.io import save_temp_img, download_file
|
||||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||||
from astrbot.api.message_components import Plain, Image, Record
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
Record,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Video,
|
||||||
|
WechatEmoji as Emoji,
|
||||||
|
)
|
||||||
from .client import SimpleGewechatClient
|
from .client import SimpleGewechatClient
|
||||||
|
|
||||||
|
|
||||||
def get_wav_duration(file_path):
|
def get_wav_duration(file_path):
|
||||||
with wave.open(file_path, 'rb') as wav_file:
|
with wave.open(file_path, "rb") as wav_file:
|
||||||
file_size = os.path.getsize(file_path)
|
file_size = os.path.getsize(file_path)
|
||||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||||
if n_frames == 2147483647:
|
if n_frames == 2147483647:
|
||||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||||
|
elif n_frames == 0:
|
||||||
|
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||||
else:
|
else:
|
||||||
duration = n_frames / float(framerate)
|
duration = n_frames / float(framerate)
|
||||||
return duration
|
return duration
|
||||||
|
|
||||||
|
|
||||||
class GewechatPlatformEvent(AstrMessageEvent):
|
class GewechatPlatformEvent(AstrMessageEvent):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -26,77 +40,192 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
message_obj: AstrBotMessage,
|
message_obj: AstrBotMessage,
|
||||||
platform_meta: PlatformMetadata,
|
platform_meta: PlatformMetadata,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
client: SimpleGewechatClient
|
client: SimpleGewechatClient,
|
||||||
):
|
):
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def send_with_client(message: MessageChain, user_name: str):
|
async def send_with_client(
|
||||||
pass
|
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
||||||
|
):
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
to_wxid = self.message_obj.raw_message.get('to_wxid', None)
|
|
||||||
|
|
||||||
if not to_wxid:
|
if not to_wxid:
|
||||||
logger.error("无法获取到 to_wxid。")
|
logger.error("无法获取到 to_wxid。")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 检查@
|
||||||
|
ats = []
|
||||||
|
ats_names = []
|
||||||
|
for comp in message.chain:
|
||||||
|
if isinstance(comp, At):
|
||||||
|
ats.append(comp.qq)
|
||||||
|
ats_names.append(comp.name)
|
||||||
|
has_at = False
|
||||||
|
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
await self.client.post_text(to_wxid, comp.text)
|
text = comp.text
|
||||||
elif isinstance(comp, Image):
|
payload = {
|
||||||
img_url = comp.file
|
"to_wxid": to_wxid,
|
||||||
img_path = ""
|
"content": text,
|
||||||
if img_url.startswith("file:///"):
|
}
|
||||||
img_path = img_url[8:]
|
if not has_at and ats:
|
||||||
elif comp.file and comp.file.startswith("http"):
|
ats = f"{','.join(ats)}"
|
||||||
img_path = await download_image_by_url(comp.file)
|
ats_names = f"@{' @'.join(ats_names)}"
|
||||||
else:
|
text = f"{ats_names} {text}"
|
||||||
img_path = img_url
|
payload["content"] = text
|
||||||
|
payload["ats"] = ats
|
||||||
|
has_at = True
|
||||||
|
await client.post_text(**payload)
|
||||||
|
|
||||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
elif isinstance(comp, Image):
|
||||||
temp_directory = os.path.abspath('data/temp')
|
img_path = await comp.convert_to_file_path()
|
||||||
img_path = os.path.abspath(img_path)
|
|
||||||
|
# 检查 record_path 是否在 data/temp 目录中
|
||||||
|
temp_directory = os.path.abspath("data/temp")
|
||||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||||
with open(img_path, "rb") as f:
|
with open(img_path, "rb") as f:
|
||||||
img_path = save_temp_img(f.read())
|
img_path = save_temp_img(f.read())
|
||||||
|
|
||||||
file_id = os.path.basename(img_path)
|
file_id = os.path.basename(img_path)
|
||||||
img_url = f"{self.client.file_server_url}/{file_id}"
|
img_url = f"{client.file_server_url}/{file_id}"
|
||||||
logger.debug(f"gewe callback img url: {img_url}")
|
logger.debug(f"gewe callback img url: {img_url}")
|
||||||
await self.client.post_image(to_wxid, img_url)
|
await client.post_image(to_wxid, img_url)
|
||||||
|
elif isinstance(comp, Video):
|
||||||
|
if comp.cover != "":
|
||||||
|
await client.forward_video(to_wxid, comp.cover)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFmpeg
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
logger.error(
|
||||||
|
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||||
|
)
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
video_url = comp.file
|
||||||
|
# 根据 url 下载视频
|
||||||
|
video_filename = f"{uuid.uuid4()}.mp4"
|
||||||
|
video_path = f"data/temp/{video_filename}"
|
||||||
|
await download_file(video_url, video_path)
|
||||||
|
|
||||||
|
# 获取视频第一帧
|
||||||
|
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||||
|
try:
|
||||||
|
ff = FFmpeg()
|
||||||
|
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
|
||||||
|
ff.options(command)
|
||||||
|
thumb_file_id = os.path.basename(thumb_path)
|
||||||
|
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取视频第一帧失败: {e}")
|
||||||
|
# 获取视频时长
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFprobe
|
||||||
|
|
||||||
|
# 创建 FFprobe 实例
|
||||||
|
ffprobe = FFprobe(video_url)
|
||||||
|
# 获取时长字符串
|
||||||
|
duration_str = ffprobe.duration
|
||||||
|
# 处理时长字符串
|
||||||
|
video_duration = float(duration_str.replace(":", ""))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取时长失败: {e}")
|
||||||
|
video_duration = 10
|
||||||
|
|
||||||
|
file_id = os.path.basename(video_path)
|
||||||
|
video_url = f"{client.file_server_url}/{file_id}"
|
||||||
|
await client.post_video(
|
||||||
|
to_wxid, video_url, thumb_url, video_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除临时视频和缩略图文件
|
||||||
|
if os.path.exists(video_path):
|
||||||
|
os.remove(video_path)
|
||||||
|
if os.path.exists(thumb_path):
|
||||||
|
os.remove(thumb_path)
|
||||||
elif isinstance(comp, Record):
|
elif isinstance(comp, Record):
|
||||||
# 默认已经存在 data/temp 中
|
# 默认已经存在 data/temp 中
|
||||||
record_url = comp.file
|
record_url = comp.file
|
||||||
record_path = ""
|
record_path = await comp.convert_to_file_path()
|
||||||
|
|
||||||
if record_url.startswith("file:///"):
|
|
||||||
record_path = record_url[8:]
|
|
||||||
elif record_url.startswith("http"):
|
|
||||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
|
||||||
else:
|
|
||||||
record_path = record_url
|
|
||||||
|
|
||||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||||
|
try:
|
||||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||||
|
except Exception as e:
|
||||||
print(f"duration: {duration}, {silk_path}")
|
logger.error(traceback.format_exc())
|
||||||
|
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
||||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||||
# temp_directory = os.path.abspath('data/temp')
|
|
||||||
# record_path = os.path.abspath(record_path)
|
|
||||||
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
|
|
||||||
# with open(record_path, "rb") as f:
|
|
||||||
# record_path = f"data/temp/{uuid.uuid4()}.wav"
|
|
||||||
# with open(record_path, "wb") as f2:
|
|
||||||
# f2.write(f.read())
|
|
||||||
|
|
||||||
if duration == 0:
|
if duration == 0:
|
||||||
duration = get_wav_duration(record_path)
|
duration = get_wav_duration(record_path)
|
||||||
|
|
||||||
file_id = os.path.basename(silk_path)
|
file_id = os.path.basename(silk_path)
|
||||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
record_url = f"{client.file_server_url}/{file_id}"
|
||||||
await self.client.post_voice(to_wxid, record_url, duration*1000)
|
logger.debug(f"gewe callback record url: {record_url}")
|
||||||
|
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||||
|
elif isinstance(comp, File):
|
||||||
|
file_path = comp.file
|
||||||
|
file_name = comp.name
|
||||||
|
if file_path.startswith("file:///"):
|
||||||
|
file_path = file_path[8:]
|
||||||
|
elif file_path.startswith("http"):
|
||||||
|
await download_file(file_path, f"data/temp/{file_name}")
|
||||||
|
else:
|
||||||
|
file_path = file_path
|
||||||
|
|
||||||
|
file_id = os.path.basename(file_path)
|
||||||
|
file_url = f"{client.file_server_url}/{file_id}"
|
||||||
|
logger.debug(f"gewe callback file url: {file_url}")
|
||||||
|
await client.post_file(to_wxid, file_url, file_id)
|
||||||
|
elif isinstance(comp, Emoji):
|
||||||
|
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||||
|
elif isinstance(comp, At):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||||
|
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def get_group(self, group_id=None, **kwargs):
|
||||||
|
# 确定有效的 group_id
|
||||||
|
if group_id is None:
|
||||||
|
group_id = self.get_group_id()
|
||||||
|
|
||||||
|
if not group_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
res = await self.client.get_group(group_id)
|
||||||
|
data: dict = res["data"]
|
||||||
|
|
||||||
|
if not data["chatroomId"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
members = [
|
||||||
|
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||||
|
for member in data.get("memberList", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
return Group(
|
||||||
|
group_id=data["chatroomId"],
|
||||||
|
group_name=data.get("nickName"),
|
||||||
|
group_avatar=data.get("smallHeadImgUrl"),
|
||||||
|
group_owner=data.get("chatRoomOwner"),
|
||||||
|
members=members,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator)
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ import os
|
|||||||
|
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
from .gewechat_event import GewechatPlatformEvent
|
from .gewechat_event import GewechatPlatformEvent
|
||||||
from .client import SimpleGewechatClient
|
from .client import SimpleGewechatClient
|
||||||
from astrbot.core.message.components import Plain
|
from astrbot import logger
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -19,45 +18,20 @@ else:
|
|||||||
|
|
||||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||||
class GewechatPlatformAdapter(Platform):
|
class GewechatPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
super().__init__(event_queue)
|
super().__init__(event_queue)
|
||||||
self.config = platform_config
|
self.config = platform_config
|
||||||
self.settingss = platform_settings
|
self.settingss = platform_settings
|
||||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
||||||
@override
|
|
||||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
|
||||||
to_wxid = session.session_id
|
|
||||||
if "_" in to_wxid:
|
|
||||||
# 群聊,开启了独立会话
|
|
||||||
_, to_wxid = to_wxid.split("_")
|
|
||||||
|
|
||||||
if not to_wxid:
|
|
||||||
logger.error("无法获取到 to_wxid。")
|
|
||||||
return
|
|
||||||
|
|
||||||
for comp in message_chain.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
await self.client.post_text(to_wxid, comp.text)
|
|
||||||
|
|
||||||
await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
return PlatformMetadata(
|
|
||||||
"gewechat",
|
|
||||||
"基于 gewechat 的 Wechat 适配器",
|
|
||||||
)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def run(self):
|
|
||||||
self.client = SimpleGewechatClient(
|
self.client = SimpleGewechatClient(
|
||||||
self.config['base_url'],
|
self.config["base_url"],
|
||||||
self.config['nickname'],
|
self.config["nickname"],
|
||||||
self.config['host'],
|
self.config["host"],
|
||||||
self.config['port'],
|
self.config["port"],
|
||||||
self._event_queue,
|
self._event_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,28 +40,61 @@ class GewechatPlatformAdapter(Platform):
|
|||||||
|
|
||||||
self.client.on_event_received = on_event_received
|
self.client.on_event_received = on_event_received
|
||||||
|
|
||||||
return self._run()
|
@override
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
session_id = session.session_id
|
||||||
|
if "#" in session_id:
|
||||||
|
# unique session
|
||||||
|
to_wxid = session_id.split("#")[1]
|
||||||
|
else:
|
||||||
|
to_wxid = session_id
|
||||||
|
|
||||||
|
await GewechatPlatformEvent.send_with_client(
|
||||||
|
message_chain, to_wxid, self.client
|
||||||
|
)
|
||||||
|
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="gewechat",
|
||||||
|
description="基于 gewechat 的 Wechat 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.client.shutdown_event.set()
|
||||||
|
await self.client.server.shutdown()
|
||||||
|
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
await self.client.logout()
|
await self.client.logout()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def run(self):
|
||||||
|
return self._run()
|
||||||
|
|
||||||
async def _run(self):
|
async def _run(self):
|
||||||
await self.client.login()
|
await self.client.login()
|
||||||
|
|
||||||
await self.client.start_polling()
|
await self.client.start_polling()
|
||||||
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
if message.type == MessageType.GROUP_MESSAGE:
|
if message.type == MessageType.GROUP_MESSAGE:
|
||||||
if self.settingss['unique_session']:
|
if self.settingss["unique_session"]:
|
||||||
message.session_id = message.sender.user_id + "_" + message.group_id
|
message.session_id = message.sender.user_id + "#" + message.group_id
|
||||||
|
|
||||||
message_event = GewechatPlatformEvent(
|
message_event = GewechatPlatformEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
message_obj=message,
|
message_obj=message,
|
||||||
platform_meta=self.meta(),
|
platform_meta=self.meta(),
|
||||||
session_id=message.session_id,
|
session_id=message.session_id,
|
||||||
client=self.client
|
client=self.client,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
def get_client(self) -> SimpleGewechatClient:
|
||||||
|
return self.client
|
||||||
|
|||||||
78
astrbot/core/platform/sources/gewechat/xml_data_parser.py
Normal file
78
astrbot/core/platform/sources/gewechat/xml_data_parser.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from defusedxml import ElementTree as eT
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
|
||||||
|
|
||||||
|
|
||||||
|
class GeweDataParser:
|
||||||
|
def __init__(self, data, is_private_chat):
|
||||||
|
self.data = data
|
||||||
|
self.is_private_chat = is_private_chat
|
||||||
|
|
||||||
|
def _format_to_xml(self):
|
||||||
|
return eT.fromstring(self.data)
|
||||||
|
|
||||||
|
def parse_mutil_49(self):
|
||||||
|
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||||
|
if appmsg_type is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
match appmsg_type.text:
|
||||||
|
case "57":
|
||||||
|
return self.parse_reply()
|
||||||
|
|
||||||
|
def parse_emoji(self) -> Emoji | None:
|
||||||
|
try:
|
||||||
|
emoji_element = self._format_to_xml().find(".//emoji")
|
||||||
|
# 提取 md5 和 len 属性
|
||||||
|
if emoji_element is not None:
|
||||||
|
md5_value = emoji_element.get("md5")
|
||||||
|
emoji_size = emoji_element.get("len")
|
||||||
|
cdnurl = emoji_element.get("cdnurl")
|
||||||
|
|
||||||
|
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||||
|
|
||||||
|
def parse_reply(self) -> Reply | None:
|
||||||
|
try:
|
||||||
|
replied_id = -1
|
||||||
|
replied_uid = 0
|
||||||
|
replied_nickname = ""
|
||||||
|
replied_content = ""
|
||||||
|
content = ""
|
||||||
|
|
||||||
|
root = self._format_to_xml()
|
||||||
|
refermsg = root.find(".//refermsg")
|
||||||
|
if refermsg is not None:
|
||||||
|
# 被引用的信息
|
||||||
|
svrid = refermsg.find("svrid")
|
||||||
|
fromusr = refermsg.find("fromusr")
|
||||||
|
displayname = refermsg.find("displayname")
|
||||||
|
refermsg_content = refermsg.find("content")
|
||||||
|
if svrid is not None:
|
||||||
|
replied_id = svrid.text
|
||||||
|
if fromusr is not None:
|
||||||
|
replied_uid = fromusr.text
|
||||||
|
if displayname is not None:
|
||||||
|
replied_nickname = displayname.text
|
||||||
|
if refermsg_content is not None:
|
||||||
|
replied_content = refermsg_content.text
|
||||||
|
|
||||||
|
# 提取引用者说的内容
|
||||||
|
title = root.find(".//appmsg/title")
|
||||||
|
if title is not None:
|
||||||
|
content = title.text
|
||||||
|
|
||||||
|
r = Reply(
|
||||||
|
id=replied_id,
|
||||||
|
chain=[Plain(content)],
|
||||||
|
sender_id=replied_uid,
|
||||||
|
sender_nickname=replied_nickname,
|
||||||
|
sender_str=replied_content,
|
||||||
|
message_str=content,
|
||||||
|
)
|
||||||
|
return r
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||||
232
astrbot/core/platform/sources/lark/lark_adapter.py
Normal file
232
astrbot/core/platform/sources/lark/lark_adapter.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from .lark_event import LarkMessageEvent
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from astrbot import logger
|
||||||
|
import lark_oapi as lark
|
||||||
|
from lark_oapi.api.im.v1 import *
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
|
||||||
|
class LarkPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
|
||||||
|
self.config = platform_config
|
||||||
|
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
self.appid = platform_config["app_id"]
|
||||||
|
self.appsecret = platform_config["app_secret"]
|
||||||
|
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
|
||||||
|
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
|
||||||
|
|
||||||
|
if not self.bot_name:
|
||||||
|
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||||
|
|
||||||
|
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||||
|
await self.convert_msg(event)
|
||||||
|
|
||||||
|
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||||
|
asyncio.create_task(on_msg_event_recv(event))
|
||||||
|
|
||||||
|
self.event_handler = (
|
||||||
|
lark.EventDispatcherHandler.builder("", "")
|
||||||
|
.register_p2_im_message_receive_v1(do_v2_msg_event)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client = lark.ws.Client(
|
||||||
|
app_id=self.appid,
|
||||||
|
app_secret=self.appsecret,
|
||||||
|
log_level=lark.LogLevel.ERROR,
|
||||||
|
domain=self.domain,
|
||||||
|
event_handler=self.event_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lark_api = (
|
||||||
|
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||||
|
wrapped = {
|
||||||
|
"zh_cn": {
|
||||||
|
"title": "",
|
||||||
|
"content": res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||||
|
id_type = "chat_id"
|
||||||
|
if "%" in session.session_id:
|
||||||
|
session.session_id = session.session_id.split("%")[1]
|
||||||
|
else:
|
||||||
|
id_type = "open_id"
|
||||||
|
|
||||||
|
request = (
|
||||||
|
CreateMessageRequest.builder()
|
||||||
|
.receive_id_type(id_type)
|
||||||
|
.request_body(
|
||||||
|
CreateMessageRequestBody.builder()
|
||||||
|
.receive_id(session.session_id)
|
||||||
|
.content(json.dumps(wrapped))
|
||||||
|
.msg_type("post")
|
||||||
|
.uuid(str(uuid.uuid4()))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.lark_api.im.v1.message.acreate(request)
|
||||||
|
|
||||||
|
if not response.success():
|
||||||
|
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||||
|
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="lark",
|
||||||
|
description="飞书机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||||
|
message = event.event.message
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.timestamp = int(message.create_time) / 1000
|
||||||
|
abm.message = []
|
||||||
|
abm.type = (
|
||||||
|
MessageType.GROUP_MESSAGE
|
||||||
|
if message.chat_type == "group"
|
||||||
|
else MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
if message.chat_type == "group":
|
||||||
|
abm.group_id = message.chat_id
|
||||||
|
abm.self_id = self.bot_name
|
||||||
|
abm.message_str = ""
|
||||||
|
|
||||||
|
at_list = {}
|
||||||
|
if message.mentions:
|
||||||
|
for m in message.mentions:
|
||||||
|
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||||
|
if m.name == self.bot_name:
|
||||||
|
abm.self_id = m.id.open_id
|
||||||
|
|
||||||
|
content_json_b = json.loads(message.content)
|
||||||
|
|
||||||
|
if message.message_type == "text":
|
||||||
|
message_str_raw = content_json_b["text"] # 带有 @ 的消息
|
||||||
|
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||||
|
# at_users = re.findall(at_pattern, message_str_raw)
|
||||||
|
# 拆分文本,去掉AT符号部分
|
||||||
|
parts = re.split(at_pattern, message_str_raw)
|
||||||
|
for i in range(len(parts)):
|
||||||
|
s = parts[i].strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
if s in at_list:
|
||||||
|
abm.message.append(at_list[s])
|
||||||
|
else:
|
||||||
|
abm.message.append(Comp.Plain(parts[i].strip()))
|
||||||
|
elif message.message_type == "post":
|
||||||
|
_ls = []
|
||||||
|
|
||||||
|
content_ls = content_json_b.get("content", [])
|
||||||
|
for comp in content_ls:
|
||||||
|
if isinstance(comp, list):
|
||||||
|
_ls.extend(comp)
|
||||||
|
elif isinstance(comp, dict):
|
||||||
|
_ls.append(comp)
|
||||||
|
content_json_b = _ls
|
||||||
|
elif message.message_type == "image":
|
||||||
|
content_json_b = [
|
||||||
|
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}
|
||||||
|
]
|
||||||
|
|
||||||
|
if message.message_type in ("post", "image"):
|
||||||
|
for comp in content_json_b:
|
||||||
|
if comp["tag"] == "at":
|
||||||
|
abm.message.append(at_list[comp["user_id"]])
|
||||||
|
elif comp["tag"] == "text" and comp["text"].strip():
|
||||||
|
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||||
|
elif comp["tag"] == "img":
|
||||||
|
image_key = comp["image_key"]
|
||||||
|
request = (
|
||||||
|
GetMessageResourceRequest.builder()
|
||||||
|
.message_id(message.message_id)
|
||||||
|
.file_key(image_key)
|
||||||
|
.type("image")
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||||
|
if not response.success():
|
||||||
|
logger.error(f"无法下载飞书图片: {image_key}")
|
||||||
|
image_bytes = response.file.read()
|
||||||
|
image_base64 = base64.b64encode(image_bytes).decode()
|
||||||
|
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||||
|
|
||||||
|
for comp in abm.message:
|
||||||
|
if isinstance(comp, Comp.Plain):
|
||||||
|
abm.message_str += comp.text
|
||||||
|
abm.message_id = message.message_id
|
||||||
|
abm.raw_message = message
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
user_id=event.event.sender.sender_id.open_id,
|
||||||
|
nickname=event.event.sender.sender_id.open_id[:8],
|
||||||
|
)
|
||||||
|
# 独立会话
|
||||||
|
if not self.unique_session:
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = abm.group_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
else:
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
|
logger.debug(abm)
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
async def handle_msg(self, abm: AstrBotMessage):
|
||||||
|
event = LarkMessageEvent(
|
||||||
|
message_str=abm.message_str,
|
||||||
|
message_obj=abm,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=abm.session_id,
|
||||||
|
bot=self.lark_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_queue.put_nowait(event)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
# self.client.start()
|
||||||
|
await self.client._connect()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
await self.client._disconnect()
|
||||||
|
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||||
|
|
||||||
|
def get_client(self) -> lark.Client:
|
||||||
|
return self.client
|
||||||
106
astrbot/core/platform/sources/lark/lark_event.py
Normal file
106
astrbot/core/platform/sources/lark/lark_event.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import lark_oapi as lark
|
||||||
|
from typing import List
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||||
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from lark_oapi.api.im.v1 import *
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
|
class LarkMessageEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self, message_str, message_obj, platform_meta, session_id, bot: lark.Client
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List:
|
||||||
|
ret = []
|
||||||
|
_stage = []
|
||||||
|
for comp in message.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
_stage.append({"tag": "md", "text": comp.text})
|
||||||
|
elif isinstance(comp, At):
|
||||||
|
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
||||||
|
elif isinstance(comp, AstrBotImage):
|
||||||
|
file_path = ""
|
||||||
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
|
file_path = comp.file.replace("file:///", "")
|
||||||
|
elif comp.file and comp.file.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(comp.file)
|
||||||
|
file_path = image_file_path
|
||||||
|
elif comp.file and comp.file.startswith("base64://"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
file_path = comp.file
|
||||||
|
|
||||||
|
request = (
|
||||||
|
CreateImageRequest.builder()
|
||||||
|
.request_body(
|
||||||
|
CreateImageRequestBody.builder()
|
||||||
|
.image_type("message")
|
||||||
|
.image(open(file_path, "rb"))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
response = await lark_client.im.v1.image.acreate(request)
|
||||||
|
if not response.success():
|
||||||
|
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||||
|
image_key = response.data.image_key
|
||||||
|
print(image_key)
|
||||||
|
ret.append(_stage)
|
||||||
|
ret.append([{"tag": "img", "image_key": image_key}])
|
||||||
|
_stage.clear()
|
||||||
|
else:
|
||||||
|
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
|
||||||
|
|
||||||
|
if _stage:
|
||||||
|
ret.append(_stage)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
|
||||||
|
wrapped = {
|
||||||
|
"zh_cn": {
|
||||||
|
"title": "",
|
||||||
|
"content": res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request = (
|
||||||
|
ReplyMessageRequest.builder()
|
||||||
|
.message_id(self.message_obj.message_id)
|
||||||
|
.request_body(
|
||||||
|
ReplyMessageRequestBody.builder()
|
||||||
|
.content(json.dumps(wrapped))
|
||||||
|
.msg_type("post")
|
||||||
|
.uuid(str(uuid.uuid4()))
|
||||||
|
.reply_in_thread(False)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.bot.im.v1.message.areply(request)
|
||||||
|
|
||||||
|
if not response.success():
|
||||||
|
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||||
|
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator)
|
||||||
@@ -2,67 +2,190 @@ import botpy
|
|||||||
import botpy.message
|
import botpy.message
|
||||||
import botpy.types
|
import botpy.types
|
||||||
import botpy.types.message
|
import botpy.types.message
|
||||||
|
import asyncio
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
from astrbot.api.message_components import Plain, Image
|
from astrbot.api.message_components import Plain, Image
|
||||||
from botpy import Client
|
from botpy import Client
|
||||||
from botpy.http import Route
|
from botpy.http import Route
|
||||||
|
from astrbot.api import logger
|
||||||
|
from botpy.types import message
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
bot: Client,
|
||||||
|
):
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.send_buffer = None
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
source = self.message_obj.raw_message
|
if not self.send_buffer:
|
||||||
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
self.send_buffer = message
|
||||||
|
else:
|
||||||
|
self.send_buffer.chain.extend(message.chain)
|
||||||
|
|
||||||
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
|
async def send_streaming(self, generator):
|
||||||
|
"""流式输出仅支持消息列表私聊"""
|
||||||
|
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||||
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
|
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||||
|
try:
|
||||||
|
async for chain in generator:
|
||||||
|
source = self.message_obj.raw_message
|
||||||
|
if not self.send_buffer:
|
||||||
|
self.send_buffer = chain
|
||||||
|
else:
|
||||||
|
self.send_buffer.chain.extend(chain.chain)
|
||||||
|
|
||||||
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
|
# 真流式传输
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
|
if time_since_last_edit >= throttle_interval:
|
||||||
|
ret = await self._post_send(stream=stream_payload)
|
||||||
|
stream_payload["index"] += 1
|
||||||
|
stream_payload["id"] = ret["id"]
|
||||||
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
|
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||||
|
stream_payload["state"] = 10
|
||||||
|
ret = await self._post_send(stream=stream_payload)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||||
|
self.send_buffer = None
|
||||||
|
|
||||||
|
return await super().send_streaming(generator)
|
||||||
|
|
||||||
|
async def _post_send(self, stream: dict = None):
|
||||||
|
if not self.send_buffer:
|
||||||
|
return
|
||||||
|
|
||||||
|
source = self.message_obj.raw_message
|
||||||
|
assert isinstance(
|
||||||
|
source,
|
||||||
|
(
|
||||||
|
botpy.message.Message,
|
||||||
|
botpy.message.GroupMessage,
|
||||||
|
botpy.message.DirectMessage,
|
||||||
|
botpy.message.C2CMessage,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
plain_text,
|
||||||
|
image_base64,
|
||||||
|
image_path,
|
||||||
|
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||||
|
|
||||||
|
if not plain_text and not image_base64 and not image_path:
|
||||||
|
return
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
'content': plain_text,
|
"content": plain_text,
|
||||||
'msg_id': self.message_obj.message_id,
|
"msg_id": self.message_obj.message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
|
||||||
|
payload["msg_seq"] = random.randint(1, 10000)
|
||||||
|
|
||||||
match type(source):
|
match type(source):
|
||||||
case botpy.message.GroupMessage:
|
case botpy.message.GroupMessage:
|
||||||
if image_base64:
|
if image_base64:
|
||||||
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
|
media = await self.upload_group_and_c2c_image(
|
||||||
payload['media'] = media
|
image_base64, 1, group_openid=source.group_openid
|
||||||
payload['msg_type'] = 7
|
)
|
||||||
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
|
payload["media"] = media
|
||||||
|
payload["msg_type"] = 7
|
||||||
|
ret = await self.bot.api.post_group_message(
|
||||||
|
group_openid=source.group_openid, **payload
|
||||||
|
)
|
||||||
case botpy.message.C2CMessage:
|
case botpy.message.C2CMessage:
|
||||||
if image_base64:
|
if image_base64:
|
||||||
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
|
media = await self.upload_group_and_c2c_image(
|
||||||
payload['media'] = media
|
image_base64, 1, openid=source.author.user_openid
|
||||||
payload['msg_type'] = 7
|
)
|
||||||
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
|
payload["media"] = media
|
||||||
|
payload["msg_type"] = 7
|
||||||
|
if stream:
|
||||||
|
ret = await self.post_c2c_message(
|
||||||
|
openid=source.author.user_openid,
|
||||||
|
**payload,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ret = await self.post_c2c_message(
|
||||||
|
openid=source.author.user_openid, **payload
|
||||||
|
)
|
||||||
|
logger.debug(f"Message sent to C2C: {ret}")
|
||||||
case botpy.message.Message:
|
case botpy.message.Message:
|
||||||
if image_path:
|
if image_path:
|
||||||
payload['file_image'] = image_path
|
payload["file_image"] = image_path
|
||||||
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
|
ret = await self.bot.api.post_message(
|
||||||
|
channel_id=source.channel_id, **payload
|
||||||
|
)
|
||||||
case botpy.message.DirectMessage:
|
case botpy.message.DirectMessage:
|
||||||
if image_path:
|
if image_path:
|
||||||
payload['file_image'] = image_path
|
payload["file_image"] = image_path
|
||||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(self.send_buffer)
|
||||||
|
|
||||||
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
|
self.send_buffer = None
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def upload_group_and_c2c_image(
|
||||||
|
self, image_base64: str, file_type: int, **kwargs
|
||||||
|
) -> botpy.types.message.Media:
|
||||||
payload = {
|
payload = {
|
||||||
'file_data': image_base64,
|
"file_data": image_base64,
|
||||||
'file_type': file_type,
|
"file_type": file_type,
|
||||||
"srv_send_msg": False
|
"srv_send_msg": False,
|
||||||
}
|
}
|
||||||
if 'openid' in kwargs:
|
if "openid" in kwargs:
|
||||||
payload['openid'] = kwargs['openid']
|
payload["openid"] = kwargs["openid"]
|
||||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs['openid'])
|
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
elif 'group_openid' in kwargs:
|
elif "group_openid" in kwargs:
|
||||||
payload['group_openid'] = kwargs['group_openid']
|
payload["group_openid"] = kwargs["group_openid"]
|
||||||
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs['group_openid'])
|
route = Route(
|
||||||
|
"POST",
|
||||||
|
"/v2/groups/{group_openid}/files",
|
||||||
|
group_openid=kwargs["group_openid"],
|
||||||
|
)
|
||||||
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
|
async def post_c2c_message(
|
||||||
|
self,
|
||||||
|
openid: str,
|
||||||
|
msg_type: int = 0,
|
||||||
|
content: str = None,
|
||||||
|
embed: message.Embed = None,
|
||||||
|
ark: message.Ark = None,
|
||||||
|
message_reference: message.Reference = None,
|
||||||
|
media: message.Media = None,
|
||||||
|
msg_id: str = None,
|
||||||
|
msg_seq: str = 1,
|
||||||
|
event_id: str = None,
|
||||||
|
markdown: message.MarkdownPayload = None,
|
||||||
|
keyboard: message.Keyboard = None,
|
||||||
|
stream: dict = None,
|
||||||
|
) -> message.Message:
|
||||||
|
payload = locals()
|
||||||
|
payload.pop("self", None)
|
||||||
|
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -75,12 +198,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
plain_text += i.text
|
plain_text += i.text
|
||||||
elif isinstance(i, Image) and not image_base64:
|
elif isinstance(i, Image) and not image_base64:
|
||||||
if i.file and i.file.startswith("file:///"):
|
if i.file and i.file.startswith("file:///"):
|
||||||
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
|
image_base64 = file_to_base64(i.file[8:])
|
||||||
image_file_path = i.file[8:]
|
image_file_path = i.file[8:]
|
||||||
elif i.file and i.file.startswith("http"):
|
elif i.file and i.file.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(i.file)
|
image_file_path = await download_image_by_url(i.file)
|
||||||
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
|
image_base64 = file_to_base64(image_file_path)
|
||||||
|
elif i.file and i.file.startswith("base64://"):
|
||||||
|
image_base64 = i.file
|
||||||
else:
|
else:
|
||||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
image_base64 = file_to_base64(i.file)
|
||||||
image_file_path = i.file
|
image_base64 = image_base64.removeprefix("base64://")
|
||||||
|
else:
|
||||||
|
logger.debug(f"qq_official 忽略 {i.type}")
|
||||||
return plain_text, image_base64, image_file_path
|
return plain_text, image_base64, image_file_path
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import botpy
|
import botpy
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@@ -8,7 +10,14 @@ import botpy.types.message
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from botpy import Client
|
from botpy import Client
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot import logger
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
from astrbot.api.message_components import Image, Plain, At
|
from astrbot.api.message_components import Image, Plain, At
|
||||||
@@ -21,67 +30,84 @@ from astrbot.core.message.components import BaseMessageComponent
|
|||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
# QQ 机器人官方框架
|
# QQ 机器人官方框架
|
||||||
class botClient(Client):
|
class botClient(Client):
|
||||||
def set_platform(self, platform: 'QQOfficialPlatformAdapter'):
|
def set_platform(self, platform: "QQOfficialPlatformAdapter"):
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
|
|
||||||
# 收到群消息
|
# 收到群消息
|
||||||
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
||||||
abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid
|
message, MessageType.GROUP_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = (
|
||||||
|
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||||
|
)
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
# 收到频道消息
|
# 收到频道消息
|
||||||
async def on_at_message_create(self, message: botpy.message.Message):
|
async def on_at_message_create(self, message: botpy.message.Message):
|
||||||
abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id
|
message, MessageType.GROUP_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = (
|
||||||
|
abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||||
|
)
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
# 收到私聊消息
|
# 收到私聊消息
|
||||||
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
||||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
|
message, MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
# 收到 C2C 消息
|
# 收到 C2C 消息
|
||||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
|
message, MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
def _commit(self, abm: AstrBotMessage):
|
def _commit(self, abm: AstrBotMessage):
|
||||||
self.platform.commit_event(QQOfficialMessageEvent(
|
self.platform.commit_event(
|
||||||
|
QQOfficialMessageEvent(
|
||||||
abm.message_str,
|
abm.message_str,
|
||||||
abm,
|
abm,
|
||||||
self.platform.meta(),
|
self.platform.meta(),
|
||||||
abm.session_id,
|
abm.session_id,
|
||||||
self.platform.client
|
self.platform.client,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器")
|
@register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器")
|
||||||
class QQOfficialPlatformAdapter(Platform):
|
class QQOfficialPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
super().__init__(event_queue)
|
super().__init__(event_queue)
|
||||||
|
|
||||||
self.config = platform_config
|
self.config = platform_config
|
||||||
|
|
||||||
self.appid = platform_config['appid']
|
self.appid = platform_config["appid"]
|
||||||
self.secret = platform_config['secret']
|
self.secret = platform_config["secret"]
|
||||||
self.unique_session = platform_settings['unique_session']
|
self.unique_session = platform_settings["unique_session"]
|
||||||
qq_group = platform_config['enable_group_c2c']
|
qq_group = platform_config["enable_group_c2c"]
|
||||||
guild_dm = platform_config['enable_guild_direct_message']
|
guild_dm = platform_config["enable_guild_direct_message"]
|
||||||
|
|
||||||
if qq_group:
|
if qq_group:
|
||||||
self.intents = botpy.Intents(
|
self.intents = botpy.Intents(
|
||||||
public_messages=True,
|
public_messages=True,
|
||||||
public_guild_messages=True,
|
public_guild_messages=True,
|
||||||
direct_message=guild_dm
|
direct_message=guild_dm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.intents = botpy.Intents(
|
self.intents = botpy.Intents(
|
||||||
public_guild_messages=True,
|
public_guild_messages=True, direct_message=guild_dm
|
||||||
direct_message=guild_dm
|
|
||||||
)
|
)
|
||||||
self.client = botClient(
|
self.client = botClient(
|
||||||
intents=self.intents,
|
intents=self.intents,
|
||||||
@@ -91,19 +117,25 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
self.client.set_platform(self)
|
self.client.set_platform(self)
|
||||||
|
|
||||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||||
|
|
||||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"qq_official",
|
name="qq_official",
|
||||||
"QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage],
|
@staticmethod
|
||||||
message_type: MessageType):
|
def _parse_from_qqofficial(
|
||||||
|
message: Union[botpy.message.Message, botpy.message.GroupMessage],
|
||||||
|
message_type: MessageType,
|
||||||
|
):
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.type = message_type
|
abm.type = message_type
|
||||||
abm.timestamp = int(time.time())
|
abm.timestamp = int(time.time())
|
||||||
@@ -112,19 +144,14 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
abm.tag = "qq_official"
|
abm.tag = "qq_official"
|
||||||
msg: List[BaseMessageComponent] = []
|
msg: List[BaseMessageComponent] = []
|
||||||
|
|
||||||
|
if isinstance(message, botpy.message.GroupMessage) or isinstance(
|
||||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
|
message, botpy.message.C2CMessage
|
||||||
|
):
|
||||||
if isinstance(message, botpy.message.GroupMessage):
|
if isinstance(message, botpy.message.GroupMessage):
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(message.author.member_openid, "")
|
||||||
message.author.member_openid,
|
|
||||||
""
|
|
||||||
)
|
|
||||||
abm.group_id = message.group_openid
|
abm.group_id = message.group_openid
|
||||||
else:
|
else:
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(message.author.user_openid, "")
|
||||||
message.author.user_openid,
|
|
||||||
""
|
|
||||||
)
|
|
||||||
abm.message_str = message.content.strip()
|
abm.message_str = message.content.strip()
|
||||||
abm.self_id = "unknown_selfid"
|
abm.self_id = "unknown_selfid"
|
||||||
msg.append(At(qq="qq_official"))
|
msg.append(At(qq="qq_official"))
|
||||||
@@ -134,33 +161,35 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
if i.content_type.startswith("image"):
|
if i.content_type.startswith("image"):
|
||||||
url = i.url
|
url = i.url
|
||||||
if not url.startswith("http"):
|
if not url.startswith("http"):
|
||||||
url = "https://"+url
|
url = "https://" + url
|
||||||
img = Image.fromURL(url)
|
img = Image.fromURL(url)
|
||||||
msg.append(img)
|
msg.append(img)
|
||||||
abm.message = msg
|
abm.message = msg
|
||||||
|
|
||||||
elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage):
|
elif isinstance(message, botpy.message.Message) or isinstance(
|
||||||
|
message, botpy.message.DirectMessage
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
abm.self_id = str(message.mentions[0].id)
|
abm.self_id = str(message.mentions[0].id)
|
||||||
except BaseException as _:
|
except BaseException as _:
|
||||||
abm.self_id = ""
|
abm.self_id = ""
|
||||||
|
|
||||||
plain_content = message.content.replace(
|
plain_content = message.content.replace(
|
||||||
"<@!"+str(abm.self_id)+">", "").strip()
|
"<@!" + str(abm.self_id) + ">", ""
|
||||||
|
).strip()
|
||||||
|
|
||||||
if message.attachments:
|
if message.attachments:
|
||||||
for i in message.attachments:
|
for i in message.attachments:
|
||||||
if i.content_type.startswith("image"):
|
if i.content_type.startswith("image"):
|
||||||
url = i.url
|
url = i.url
|
||||||
if not url.startswith("http"):
|
if not url.startswith("http"):
|
||||||
url = "https://"+url
|
url = "https://" + url
|
||||||
img = Image.fromURL(url)
|
img = Image.fromURL(url)
|
||||||
msg.append(img)
|
msg.append(img)
|
||||||
abm.message = msg
|
abm.message = msg
|
||||||
abm.message_str = plain_content
|
abm.message_str = plain_content
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
str(message.author.id),
|
str(message.author.id), str(message.author.username)
|
||||||
str(message.author.username)
|
|
||||||
)
|
)
|
||||||
msg.append(At(qq="qq_official"))
|
msg.append(At(qq="qq_official"))
|
||||||
msg.append(Plain(plain_content))
|
msg.append(Plain(plain_content))
|
||||||
@@ -173,7 +202,11 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
return abm
|
return abm
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
return self.client.start(
|
return self.client.start(appid=self.appid, secret=self.secret)
|
||||||
appid=self.appid,
|
|
||||||
secret=self.secret
|
def get_client(self) -> botClient:
|
||||||
)
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
await self.client.close()
|
||||||
|
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
|
||||||
|
|||||||
@@ -0,0 +1,124 @@
|
|||||||
|
import botpy
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import botpy.message
|
||||||
|
import botpy.types
|
||||||
|
import botpy.types.message
|
||||||
|
|
||||||
|
from botpy import Client
|
||||||
|
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from .qo_webhook_event import QQOfficialWebhookMessageEvent
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from .qo_webhook_server import QQOfficialWebhook
|
||||||
|
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
# remove logger handler
|
||||||
|
for handler in logging.root.handlers[:]:
|
||||||
|
logging.root.removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
# QQ 机器人官方框架
|
||||||
|
class botClient(Client):
|
||||||
|
def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"):
|
||||||
|
self.platform = platform
|
||||||
|
|
||||||
|
# 收到群消息
|
||||||
|
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
||||||
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
|
message, MessageType.GROUP_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = (
|
||||||
|
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||||
|
)
|
||||||
|
self._commit(abm)
|
||||||
|
|
||||||
|
# 收到频道消息
|
||||||
|
async def on_at_message_create(self, message: botpy.message.Message):
|
||||||
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
|
message, MessageType.GROUP_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = (
|
||||||
|
abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||||
|
)
|
||||||
|
self._commit(abm)
|
||||||
|
|
||||||
|
# 收到私聊消息
|
||||||
|
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
||||||
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
|
message, MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
self._commit(abm)
|
||||||
|
|
||||||
|
# 收到 C2C 消息
|
||||||
|
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||||
|
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||||
|
message, MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
self._commit(abm)
|
||||||
|
|
||||||
|
def _commit(self, abm: AstrBotMessage):
|
||||||
|
self.platform.commit_event(
|
||||||
|
QQOfficialWebhookMessageEvent(
|
||||||
|
abm.message_str, abm, self.platform.meta(), abm.session_id, self
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("qq_official_webhook", "QQ 机器人官方 API 适配器(Webhook)")
|
||||||
|
class QQOfficialWebhookPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
|
||||||
|
self.config = platform_config
|
||||||
|
|
||||||
|
self.appid = platform_config["appid"]
|
||||||
|
self.secret = platform_config["secret"]
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
intents = botpy.Intents(
|
||||||
|
public_messages=True, public_guild_messages=True, direct_message=True
|
||||||
|
)
|
||||||
|
self.client = botClient(
|
||||||
|
intents=intents, # 已经无用
|
||||||
|
bot_log=False,
|
||||||
|
timeout=20,
|
||||||
|
)
|
||||||
|
self.client.set_platform(self)
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="qq_official_webhook",
|
||||||
|
description="QQ 机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
self.webhook_helper = QQOfficialWebhook(
|
||||||
|
self.config, self._event_queue, self.client
|
||||||
|
)
|
||||||
|
await self.webhook_helper.initialize()
|
||||||
|
await self.webhook_helper.start_polling()
|
||||||
|
|
||||||
|
def get_client(self) -> botClient:
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.webhook_helper.shutdown_event.set()
|
||||||
|
await self.client.close()
|
||||||
|
try:
|
||||||
|
await self.webhook_helper.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
|
from botpy import Client
|
||||||
|
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
|
||||||
|
|
||||||
|
|
||||||
|
class QQOfficialWebhookMessageEvent(QQOfficialMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
bot: Client,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id, bot)
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
import quart
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from botpy import BotAPI, BotHttp, Client, Token, BotWebSocket, ConnectionSession
|
||||||
|
from astrbot.api import logger
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
|
|
||||||
|
# remove logger handler
|
||||||
|
for handler in logging.root.handlers[:]:
|
||||||
|
logging.root.removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
class QQOfficialWebhook:
|
||||||
|
def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client):
|
||||||
|
self.appid = config["appid"]
|
||||||
|
self.secret = config["secret"]
|
||||||
|
self.port = config.get("port", 6196)
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
|
|
||||||
|
if isinstance(self.port, str):
|
||||||
|
self.port = int(self.port)
|
||||||
|
|
||||||
|
self.http: BotHttp = BotHttp(timeout=300)
|
||||||
|
self.api: BotAPI = BotAPI(http=self.http)
|
||||||
|
self.token = Token(self.appid, self.secret)
|
||||||
|
|
||||||
|
self.server = quart.Quart(__name__)
|
||||||
|
self.server.add_url_rule(
|
||||||
|
"/astrbot-qo-webhook/callback", view_func=self.callback, methods=["POST"]
|
||||||
|
)
|
||||||
|
self.client = botpy_client
|
||||||
|
self.event_queue = event_queue
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
logger.info("正在登录到 QQ 官方机器人...")
|
||||||
|
self.user = await self.http.login(self.token)
|
||||||
|
logger.info(f"已登录 QQ 官方机器人账号: {self.user}")
|
||||||
|
# 直接注入到 botpy 的 Client,移花接木!
|
||||||
|
self.client.api = self.api
|
||||||
|
self.client.http = self.http
|
||||||
|
|
||||||
|
async def bot_connect():
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._connection = ConnectionSession(
|
||||||
|
max_async=1,
|
||||||
|
connect=bot_connect,
|
||||||
|
dispatch=self.client.ws_dispatch,
|
||||||
|
loop=asyncio.get_event_loop(),
|
||||||
|
api=self.api,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def repeat_seed(self, bot_secret: str, target_size: int = 32) -> bytes:
|
||||||
|
seed = bot_secret
|
||||||
|
while len(seed) < target_size:
|
||||||
|
seed *= 2
|
||||||
|
return seed[:target_size].encode("utf-8")
|
||||||
|
|
||||||
|
async def webhook_validation(self, validation_payload: dict):
|
||||||
|
seed = await self.repeat_seed(self.secret)
|
||||||
|
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||||
|
msg = validation_payload.get("event_ts", "") + validation_payload.get(
|
||||||
|
"plain_token", ""
|
||||||
|
)
|
||||||
|
# sign
|
||||||
|
signature = private_key.sign(msg.encode()).hex()
|
||||||
|
response = {
|
||||||
|
"plain_token": validation_payload.get("plain_token"),
|
||||||
|
"signature": signature,
|
||||||
|
}
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def callback(self):
|
||||||
|
msg: dict = await quart.request.json
|
||||||
|
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
|
||||||
|
|
||||||
|
event = msg.get("t")
|
||||||
|
opcode = msg.get("op")
|
||||||
|
data = msg.get("d")
|
||||||
|
|
||||||
|
if opcode == 13:
|
||||||
|
# validation
|
||||||
|
signed = await self.webhook_validation(data)
|
||||||
|
print(signed)
|
||||||
|
return signed
|
||||||
|
|
||||||
|
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
|
||||||
|
event = msg["t"].lower()
|
||||||
|
try:
|
||||||
|
func = self._connection.parser[event]
|
||||||
|
except KeyError:
|
||||||
|
logger.error("_parser unknown event %s.", event)
|
||||||
|
else:
|
||||||
|
func(msg)
|
||||||
|
|
||||||
|
return {"opcode": 12}
|
||||||
|
|
||||||
|
async def start_polling(self):
|
||||||
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。"
|
||||||
|
)
|
||||||
|
await self.server.run_task(
|
||||||
|
host=self.callback_server_host,
|
||||||
|
port=self.port,
|
||||||
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown_trigger(self):
|
||||||
|
await self.shutdown_event.wait()
|
||||||
350
astrbot/core/platform/sources/telegram/tg_adapter.py
Normal file
350
astrbot/core/platform/sources/telegram/tg_adapter.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from telegram import BotCommand, Update
|
||||||
|
from telegram.constants import ChatType
|
||||||
|
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
|
||||||
|
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||||
|
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
Platform,
|
||||||
|
PlatformMetadata,
|
||||||
|
register_platform_adapter,
|
||||||
|
)
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from astrbot.core.star.filter.command import CommandFilter
|
||||||
|
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
|
|
||||||
|
from .tg_event import TelegramPlatformEvent
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12):
|
||||||
|
from typing import override
|
||||||
|
else:
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("telegram", "telegram 适配器")
|
||||||
|
class TelegramPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self.config = platform_config
|
||||||
|
self.settings = platform_settings
|
||||||
|
self.client_self_id = uuid.uuid4().hex[:8]
|
||||||
|
|
||||||
|
base_url = self.config.get(
|
||||||
|
"telegram_api_base_url", "https://api.telegram.org/bot"
|
||||||
|
)
|
||||||
|
if not base_url:
|
||||||
|
base_url = "https://api.telegram.org/bot"
|
||||||
|
|
||||||
|
file_base_url = self.config.get(
|
||||||
|
"telegram_file_base_url", "https://api.telegram.org/file/bot"
|
||||||
|
)
|
||||||
|
if not file_base_url:
|
||||||
|
file_base_url = "https://api.telegram.org/file/bot"
|
||||||
|
|
||||||
|
self.base_url = base_url
|
||||||
|
|
||||||
|
self.application = (
|
||||||
|
ApplicationBuilder()
|
||||||
|
.token(self.config["telegram_token"])
|
||||||
|
.base_url(base_url)
|
||||||
|
.base_file_url(file_base_url)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
message_handler = TelegramMessageHandler(
|
||||||
|
filters=filters.ALL, # receive all messages
|
||||||
|
callback=self.message_handler,
|
||||||
|
)
|
||||||
|
self.application.add_handler(message_handler)
|
||||||
|
self.client = self.application.bot
|
||||||
|
logger.debug(f"Telegram base url: {self.client.base_url}")
|
||||||
|
|
||||||
|
self.scheduler = AsyncIOScheduler()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
from_username = session.session_id
|
||||||
|
await TelegramPlatformEvent.send_with_client(
|
||||||
|
self.client, message_chain, from_username
|
||||||
|
)
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def run(self):
|
||||||
|
await self.application.initialize()
|
||||||
|
await self.application.start()
|
||||||
|
await self.register_commands()
|
||||||
|
|
||||||
|
# TODO 使用更优雅的方式重新注册命令
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self.register_commands,
|
||||||
|
"interval",
|
||||||
|
minutes=5,
|
||||||
|
id="telegram_command_register",
|
||||||
|
misfire_grace_time=60,
|
||||||
|
)
|
||||||
|
self.scheduler.start()
|
||||||
|
|
||||||
|
queue = self.application.updater.start_polling()
|
||||||
|
logger.info("Telegram Platform Adapter is running.")
|
||||||
|
await queue
|
||||||
|
|
||||||
|
async def register_commands(self):
|
||||||
|
"""收集所有注册的指令并注册到 Telegram"""
|
||||||
|
try:
|
||||||
|
await self.client.delete_my_commands()
|
||||||
|
commands = self.collect_commands()
|
||||||
|
|
||||||
|
if commands:
|
||||||
|
await self.client.set_my_commands(commands)
|
||||||
|
for cmd in commands:
|
||||||
|
logger.debug(f"已注册指令: /{cmd.command} - {cmd.description}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
|
||||||
|
|
||||||
|
def collect_commands(self) -> list[BotCommand]:
|
||||||
|
"""从注册的处理器中收集所有指令"""
|
||||||
|
command_dict = {}
|
||||||
|
skip_commands = {"start"}
|
||||||
|
|
||||||
|
for handler_md in star_handlers_registry._handlers:
|
||||||
|
handler_metadata = handler_md[1]
|
||||||
|
if not star_map[handler_metadata.handler_module_path].activated:
|
||||||
|
continue
|
||||||
|
for event_filter in handler_metadata.event_filters:
|
||||||
|
cmd_info = self._extract_command_info(
|
||||||
|
event_filter, handler_metadata, skip_commands
|
||||||
|
)
|
||||||
|
if cmd_info:
|
||||||
|
cmd_name, description = cmd_info
|
||||||
|
command_dict.setdefault(cmd_name, description)
|
||||||
|
|
||||||
|
commands_a = sorted(command_dict.keys())
|
||||||
|
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_command_info(
|
||||||
|
event_filter, handler_metadata, skip_commands: set
|
||||||
|
) -> tuple[str, str] | None:
|
||||||
|
"""从事件过滤器中提取指令信息"""
|
||||||
|
cmd_name = None
|
||||||
|
is_group = False
|
||||||
|
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
|
||||||
|
if (
|
||||||
|
event_filter.parent_command_names
|
||||||
|
and event_filter.parent_command_names != [""]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
cmd_name = event_filter.command_name
|
||||||
|
elif isinstance(event_filter, CommandGroupFilter):
|
||||||
|
if event_filter.parent_group:
|
||||||
|
return None
|
||||||
|
cmd_name = event_filter.group_name
|
||||||
|
is_group = True
|
||||||
|
|
||||||
|
if not cmd_name or cmd_name in skip_commands:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build description.
|
||||||
|
description = handler_metadata.desc or (
|
||||||
|
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
|
||||||
|
)
|
||||||
|
if len(description) > 30:
|
||||||
|
description = description[:30] + "..."
|
||||||
|
return cmd_name, description
|
||||||
|
|
||||||
|
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
await context.bot.send_message(
|
||||||
|
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
logger.debug(f"Telegram message: {update.message}")
|
||||||
|
abm = await self.convert_message(update, context)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
async def convert_message(
|
||||||
|
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||||
|
|
||||||
|
@param update: Telegram 的 Update 对象。
|
||||||
|
@param context: Telegram 的 Context 对象。
|
||||||
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
|
"""
|
||||||
|
message = AstrBotMessage()
|
||||||
|
message.session_id = str(update.message.chat.id)
|
||||||
|
# 获得是群聊还是私聊
|
||||||
|
if update.message.chat.type == ChatType.PRIVATE:
|
||||||
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
|
else:
|
||||||
|
message.type = MessageType.GROUP_MESSAGE
|
||||||
|
message.group_id = str(update.message.chat.id)
|
||||||
|
if update.message.message_thread_id:
|
||||||
|
# Topic Group
|
||||||
|
message.group_id += "#" + str(update.message.message_thread_id)
|
||||||
|
message.session_id = message.group_id
|
||||||
|
|
||||||
|
message.message_id = str(update.message.message_id)
|
||||||
|
message.sender = MessageMember(
|
||||||
|
str(update.message.from_user.id), update.message.from_user.username
|
||||||
|
)
|
||||||
|
message.self_id = str(context.bot.username)
|
||||||
|
message.raw_message = update
|
||||||
|
message.message_str = ""
|
||||||
|
message.message = []
|
||||||
|
|
||||||
|
if update.message.reply_to_message and not (
|
||||||
|
update.message.is_topic_message
|
||||||
|
and update.message.message_thread_id
|
||||||
|
== update.message.reply_to_message.message_id
|
||||||
|
):
|
||||||
|
# 获取回复消息
|
||||||
|
reply_update = Update(
|
||||||
|
update_id=1,
|
||||||
|
message=update.message.reply_to_message,
|
||||||
|
)
|
||||||
|
reply_abm = await self.convert_message(reply_update, context, False)
|
||||||
|
|
||||||
|
message.message.append(
|
||||||
|
Comp.Reply(
|
||||||
|
id=reply_abm.message_id,
|
||||||
|
chain=reply_abm.message,
|
||||||
|
sender_id=reply_abm.sender.user_id,
|
||||||
|
sender_nickname=reply_abm.sender.nickname,
|
||||||
|
time=reply_abm.timestamp,
|
||||||
|
message_str=reply_abm.message_str,
|
||||||
|
text=reply_abm.message_str,
|
||||||
|
qq=reply_abm.sender.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if update.message.text:
|
||||||
|
# 处理文本消息
|
||||||
|
plain_text = update.message.text
|
||||||
|
|
||||||
|
# 群聊场景命令特殊处理
|
||||||
|
if plain_text.startswith("/"):
|
||||||
|
command_parts = plain_text.split(" ", 1)
|
||||||
|
if "@" in command_parts[0]:
|
||||||
|
command, bot_name = command_parts[0].split("@")
|
||||||
|
if bot_name == self.client.username:
|
||||||
|
plain_text = command + (
|
||||||
|
f" {command_parts[1]}" if len(command_parts) > 1 else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
if update.message.entities:
|
||||||
|
for entity in update.message.entities:
|
||||||
|
if entity.type == "mention":
|
||||||
|
name = plain_text[
|
||||||
|
entity.offset + 1 : entity.offset + entity.length
|
||||||
|
]
|
||||||
|
message.message.append(Comp.At(qq=name, name=name))
|
||||||
|
plain_text = (
|
||||||
|
plain_text[: entity.offset]
|
||||||
|
+ plain_text[entity.offset + entity.length :]
|
||||||
|
)
|
||||||
|
|
||||||
|
if plain_text:
|
||||||
|
message.message.append(Comp.Plain(plain_text))
|
||||||
|
message.message_str = plain_text
|
||||||
|
|
||||||
|
if message.message_str.strip() == "/start":
|
||||||
|
await self.start(update, context)
|
||||||
|
return
|
||||||
|
|
||||||
|
elif update.message.voice:
|
||||||
|
file = await update.message.voice.get_file()
|
||||||
|
message.message = [
|
||||||
|
Comp.Record(file=file.file_path, url=file.file_path),
|
||||||
|
]
|
||||||
|
|
||||||
|
elif update.message.photo:
|
||||||
|
photo = update.message.photo[-1] # get the largest photo
|
||||||
|
file = await photo.get_file()
|
||||||
|
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||||
|
if update.message.caption:
|
||||||
|
message.message_str = update.message.caption
|
||||||
|
message.message.append(Comp.Plain(message.message_str))
|
||||||
|
if update.message.caption_entities:
|
||||||
|
for entity in update.message.caption_entities:
|
||||||
|
if entity.type == "mention":
|
||||||
|
name = message.message_str[
|
||||||
|
entity.offset + 1 : entity.offset + entity.length
|
||||||
|
]
|
||||||
|
message.message.append(Comp.At(qq=name, name=name))
|
||||||
|
|
||||||
|
elif update.message.sticker:
|
||||||
|
# 将sticker当作图片处理
|
||||||
|
file = await update.message.sticker.get_file()
|
||||||
|
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||||
|
if update.message.sticker.emoji:
|
||||||
|
sticker_text = f"Sticker: {update.message.sticker.emoji}"
|
||||||
|
message.message_str = sticker_text
|
||||||
|
message.message.append(Comp.Plain(sticker_text))
|
||||||
|
|
||||||
|
elif update.message.document:
|
||||||
|
file = await update.message.document.get_file()
|
||||||
|
message.message = [
|
||||||
|
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||||
|
]
|
||||||
|
|
||||||
|
elif update.message.video:
|
||||||
|
file = await update.message.video.get_file()
|
||||||
|
message.message = [
|
||||||
|
Comp.Video(file=file.file_path, path=file.file_path),
|
||||||
|
]
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
message_event = TelegramPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.client,
|
||||||
|
)
|
||||||
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
def get_client(self) -> ExtBot:
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
try:
|
||||||
|
if self.scheduler.running:
|
||||||
|
self.scheduler.shutdown()
|
||||||
|
|
||||||
|
await self.application.stop()
|
||||||
|
await self.client.delete_my_commands()
|
||||||
|
|
||||||
|
# 保险起见先判断是否存在updater对象
|
||||||
|
if self.application.updater is not None:
|
||||||
|
await self.application.updater.stop()
|
||||||
|
|
||||||
|
logger.info("Telegram 适配器已被优雅地关闭")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Telegram 适配器关闭时出错: {e}")
|
||||||
198
astrbot/core/platform/sources/telegram/tg_event.py
Normal file
198
astrbot/core/platform/sources/telegram/tg_event.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
import asyncio
|
||||||
|
import telegramify_markdown
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||||
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
Reply,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Record,
|
||||||
|
)
|
||||||
|
from telegram.ext import ExtBot
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramPlatformEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
client: ExtBot,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
|
||||||
|
image_path = None
|
||||||
|
|
||||||
|
has_reply = False
|
||||||
|
reply_message_id = None
|
||||||
|
at_user_id = None
|
||||||
|
for i in message.chain:
|
||||||
|
if isinstance(i, Reply):
|
||||||
|
has_reply = True
|
||||||
|
reply_message_id = i.id
|
||||||
|
if isinstance(i, At):
|
||||||
|
at_user_id = i.name
|
||||||
|
|
||||||
|
at_flag = False
|
||||||
|
message_thread_id = None
|
||||||
|
if "#" in user_name:
|
||||||
|
# it's a supergroup chat with message_thread_id
|
||||||
|
user_name, message_thread_id = user_name.split("#")
|
||||||
|
for i in message.chain:
|
||||||
|
payload = {
|
||||||
|
"chat_id": user_name,
|
||||||
|
}
|
||||||
|
if has_reply:
|
||||||
|
payload["reply_to_message_id"] = reply_message_id
|
||||||
|
if message_thread_id:
|
||||||
|
payload["message_thread_id"] = message_thread_id
|
||||||
|
|
||||||
|
if isinstance(i, Plain):
|
||||||
|
if at_user_id and not at_flag:
|
||||||
|
i.text = f"@{at_user_id} " + i.text
|
||||||
|
at_flag = True
|
||||||
|
text = i.text
|
||||||
|
try:
|
||||||
|
text = telegramify_markdown.markdownify(
|
||||||
|
i.text, max_line_length=None, normalize_whitespace=False
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||||
|
elif isinstance(i, Image):
|
||||||
|
image_path = await i.convert_to_file_path()
|
||||||
|
await client.send_photo(photo=image_path, **payload)
|
||||||
|
elif isinstance(i, File):
|
||||||
|
if i.file.startswith("https://"):
|
||||||
|
path = "data/temp/" + i.name
|
||||||
|
await download_file(i.file, path)
|
||||||
|
i.file = path
|
||||||
|
|
||||||
|
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||||
|
elif isinstance(i, Record):
|
||||||
|
path = await i.convert_to_file_path()
|
||||||
|
await client.send_voice(voice=path, **payload)
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
|
await self.send_with_client(self.client, message, self.message_obj.group_id)
|
||||||
|
else:
|
||||||
|
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
message_thread_id = None
|
||||||
|
|
||||||
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
|
user_name = self.message_obj.group_id
|
||||||
|
else:
|
||||||
|
user_name = self.get_sender_id()
|
||||||
|
|
||||||
|
if "#" in user_name:
|
||||||
|
# it's a supergroup chat with message_thread_id
|
||||||
|
user_name, message_thread_id = user_name.split("#")
|
||||||
|
payload = {
|
||||||
|
"chat_id": user_name,
|
||||||
|
}
|
||||||
|
if message_thread_id:
|
||||||
|
payload["reply_to_message_id"] = message_thread_id
|
||||||
|
|
||||||
|
delta = ""
|
||||||
|
current_content = ""
|
||||||
|
message_id = None
|
||||||
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
|
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
# 处理消息链中的每个组件
|
||||||
|
for i in chain.chain:
|
||||||
|
if isinstance(i, Plain):
|
||||||
|
delta += i.text
|
||||||
|
elif isinstance(i, Image):
|
||||||
|
image_path = await i.convert_to_file_path()
|
||||||
|
await self.client.send_photo(photo=image_path, **payload)
|
||||||
|
continue
|
||||||
|
elif isinstance(i, File):
|
||||||
|
if i.file.startswith("https://"):
|
||||||
|
path = "data/temp/" + i.name
|
||||||
|
await download_file(i.file, path)
|
||||||
|
i.file = path
|
||||||
|
|
||||||
|
await self.client.send_document(
|
||||||
|
document=i.file, filename=i.name, **payload
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif isinstance(i, Record):
|
||||||
|
path = await i.convert_to_file_path()
|
||||||
|
await self.client.send_voice(voice=path, **payload)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Plain
|
||||||
|
if not message_id:
|
||||||
|
try:
|
||||||
|
msg = await self.client.send_message(text=delta, **payload)
|
||||||
|
current_content = delta
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
|
message_id = msg.message_id
|
||||||
|
last_edit_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
) # 记录初始消息发送时间
|
||||||
|
else:
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
|
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
|
||||||
|
if time_since_last_edit >= throttle_interval:
|
||||||
|
# 编辑消息
|
||||||
|
try:
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
current_content = delta
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
|
last_edit_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
) # 更新上次编辑的时间
|
||||||
|
|
||||||
|
try:
|
||||||
|
if delta and current_content != delta:
|
||||||
|
try:
|
||||||
|
markdown_text = telegramify_markdown.markdownify(
|
||||||
|
delta, max_line_length=None, normalize_whitespace=False
|
||||||
|
)
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=markdown_text,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
parse_mode="MarkdownV2"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
|
|
||||||
|
return await super().send_streaming(generator)
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
import random
|
|
||||||
import asyncio
|
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
|
||||||
from astrbot.api.message_components import Plain, Image
|
|
||||||
from vchat import Core
|
|
||||||
|
|
||||||
class VChatPlatformEvent(AstrMessageEvent):
|
|
||||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
||||||
self.client = client
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def send_with_client(client: Core, message: MessageChain, user_name: str):
|
|
||||||
plain = ""
|
|
||||||
for comp in message.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
if message.is_split_:
|
|
||||||
await client.send_msg(comp.text, user_name)
|
|
||||||
else:
|
|
||||||
plain += comp.text
|
|
||||||
elif isinstance(comp, Image):
|
|
||||||
if comp.file and comp.file.startswith("file:///"):
|
|
||||||
file_path = comp.file.replace("file:///", "")
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
await client.send_image(user_name, fd=f)
|
|
||||||
elif comp.file and comp.file.startswith("http"):
|
|
||||||
image_path = await download_image_by_url(comp.file)
|
|
||||||
with open(image_path, "rb") as f:
|
|
||||||
await client.send_image(user_name, fd=f)
|
|
||||||
else:
|
|
||||||
logger.error(f"不支持的 vchat(微信适配器) 消息类型: {comp}")
|
|
||||||
await asyncio.sleep(random.uniform(0.5, 1.5)) # 🤓
|
|
||||||
|
|
||||||
if plain:
|
|
||||||
await client.send_msg(plain, user_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
|
|
||||||
await super().send(message)
|
|
||||||
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
import sys
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
|
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.api.message_components import *
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
|
||||||
from .vchat_message_event import VChatPlatformEvent
|
|
||||||
from ...register import register_platform_adapter
|
|
||||||
|
|
||||||
from vchat import Core
|
|
||||||
from vchat import model
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
|
||||||
from typing import override
|
|
||||||
else:
|
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器")
|
|
||||||
class VChatPlatformAdapter(Platform):
|
|
||||||
|
|
||||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
|
||||||
super().__init__(event_queue)
|
|
||||||
self.config = platform_config
|
|
||||||
self.settingss = platform_settings
|
|
||||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
|
||||||
self.client_self_id = uuid.uuid4().hex[:8]
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
|
||||||
from_username = session.session_id.split('$$')[0]
|
|
||||||
await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username)
|
|
||||||
await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
return PlatformMetadata(
|
|
||||||
"vchat",
|
|
||||||
"基于 VChat 的 Wechat 适配器",
|
|
||||||
)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def run(self):
|
|
||||||
self.client = Core()
|
|
||||||
@self.client.msg_register(msg_types=model.ContentTypes.TEXT,
|
|
||||||
contact_type=model.ContactTypes.CHATROOM | model.ContactTypes.USER)
|
|
||||||
async def _(msg: model.Message):
|
|
||||||
if isinstance(msg.content, model.UselessContent):
|
|
||||||
return
|
|
||||||
if msg.create_time < self.start_time:
|
|
||||||
logger.debug(f"忽略旧消息: {msg}")
|
|
||||||
return
|
|
||||||
logger.debug(f"收到消息: {msg.todict()}")
|
|
||||||
abmsg = self.convert_message(msg)
|
|
||||||
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
|
|
||||||
asyncio.create_task(self.handle_msg(abmsg))
|
|
||||||
|
|
||||||
# TODO: 对齐微信服务器时间
|
|
||||||
self.start_time = int(time.time())
|
|
||||||
return self._run()
|
|
||||||
|
|
||||||
|
|
||||||
async def _run(self):
|
|
||||||
await self.client.init()
|
|
||||||
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
|
|
||||||
await self.client.run()
|
|
||||||
|
|
||||||
def convert_message(self, msg: model.Message) -> AstrBotMessage:
|
|
||||||
# credits: https://github.com/z2z63/astrbot_plugin_vchat/blob/master/main.py#L49
|
|
||||||
assert isinstance(msg.content, model.TextContent)
|
|
||||||
amsg = AstrBotMessage()
|
|
||||||
amsg.message = [Plain(msg.content.content)]
|
|
||||||
amsg.self_id = self.client_self_id
|
|
||||||
if msg.content.is_at_me:
|
|
||||||
amsg.message.insert(0, At(qq=amsg.self_id))
|
|
||||||
|
|
||||||
sender = msg.chatroom_sender or msg.from_
|
|
||||||
amsg.sender = MessageMember(sender.username, sender.nickname)
|
|
||||||
|
|
||||||
if msg.content.is_at_me:
|
|
||||||
amsg.message_str = msg.content.content.split("\u2005")[1].strip()
|
|
||||||
else:
|
|
||||||
amsg.message_str = msg.content.content
|
|
||||||
amsg.message_id = msg.message_id
|
|
||||||
if isinstance(msg.from_, model.User):
|
|
||||||
amsg.type = MessageType.FRIEND_MESSAGE
|
|
||||||
elif isinstance(msg.from_, model.Chatroom):
|
|
||||||
amsg.type = MessageType.GROUP_MESSAGE
|
|
||||||
amsg.group_id = msg.from_.username
|
|
||||||
else:
|
|
||||||
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
|
|
||||||
|
|
||||||
amsg.raw_message = msg
|
|
||||||
|
|
||||||
if self.settingss['unique_session']:
|
|
||||||
session_id = msg.from_.username + "$$" + msg.to.username
|
|
||||||
if msg.chatroom_sender is not None:
|
|
||||||
session_id += '$$' + msg.chatroom_sender.username
|
|
||||||
else:
|
|
||||||
session_id = msg.from_.username
|
|
||||||
|
|
||||||
amsg.session_id = session_id
|
|
||||||
return amsg
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
|
||||||
message_event = VChatPlatformEvent(
|
|
||||||
message_str=message.message_str,
|
|
||||||
message_obj=message,
|
|
||||||
platform_meta=self.meta(),
|
|
||||||
session_id=message.session_id,
|
|
||||||
client=self.client
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"处理消息: {message_event}")
|
|
||||||
|
|
||||||
self.commit_event(message_event)
|
|
||||||
@@ -3,15 +3,22 @@ import asyncio
|
|||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any
|
||||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
from astrbot.core.platform import (
|
||||||
from astrbot.api.event import MessageChain
|
Platform,
|
||||||
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
|
AstrBotMessage,
|
||||||
from astrbot.api import logger
|
MessageMember,
|
||||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core import web_chat_queue
|
||||||
from .webchat_event import WebChatMessageEvent
|
from .webchat_event import WebChatMessageEvent
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
|
|
||||||
|
|
||||||
class QueueListener:
|
class QueueListener:
|
||||||
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
|
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
@@ -22,36 +29,32 @@ class QueueListener:
|
|||||||
data = await self.queue.get()
|
data = await self.queue.get()
|
||||||
await self.callback(data)
|
await self.callback(data)
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("webchat", "webchat")
|
@register_platform_adapter("webchat", "webchat")
|
||||||
class WebChatAdapter(Platform):
|
class WebChatAdapter(Platform):
|
||||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
super().__init__(event_queue)
|
super().__init__(event_queue)
|
||||||
|
|
||||||
self.config = platform_config
|
self.config = platform_config
|
||||||
self.settings = platform_settings
|
self.settings = platform_settings
|
||||||
self.unique_session = platform_settings['unique_session']
|
self.unique_session = platform_settings["unique_session"]
|
||||||
self.imgs_dir = "data/webchat/imgs"
|
self.imgs_dir = "data/webchat/imgs"
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
"webchat",
|
name="webchat", description="webchat", id=self.config.get("id")
|
||||||
"webchat",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
async def send_by_session(
|
||||||
# abm.session_id = f"webchat!{username}!{cid}"
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
plain = ""
|
):
|
||||||
cid = session.session_id.split("!")[-1]
|
await WebChatMessageEvent._send(message_chain, session.session_id)
|
||||||
for comp in message_chain.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
plain += comp.text
|
|
||||||
web_chat_back_queue.put_nowait((plain, cid))
|
|
||||||
|
|
||||||
await super().send_by_session(session, message_chain)
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||||
username, cid, payload = data
|
username, cid, payload = data
|
||||||
|
|
||||||
|
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = "webchat"
|
abm.self_id = "webchat"
|
||||||
abm.tag = "webchat"
|
abm.tag = "webchat"
|
||||||
@@ -64,26 +67,32 @@ class WebChatAdapter(Platform):
|
|||||||
abm.message_id = str(uuid.uuid4())
|
abm.message_id = str(uuid.uuid4())
|
||||||
abm.message = []
|
abm.message = []
|
||||||
|
|
||||||
if payload['message']:
|
if payload["message"]:
|
||||||
abm.message.append(Plain(payload['message']))
|
abm.message.append(Plain(payload["message"]))
|
||||||
if payload['image_url']:
|
if payload["image_url"]:
|
||||||
if isinstance(payload['image_url'], list):
|
if isinstance(payload["image_url"], list):
|
||||||
for img in payload['image_url']:
|
for img in payload["image_url"]:
|
||||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
|
abm.message.append(
|
||||||
|
Image.fromFileSystem(os.path.join(self.imgs_dir, img))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
|
abm.message.append(
|
||||||
if payload['audio_url']:
|
Image.fromFileSystem(
|
||||||
if isinstance(payload['audio_url'], list):
|
os.path.join(self.imgs_dir, payload["image_url"])
|
||||||
for audio in payload['audio_url']:
|
)
|
||||||
|
)
|
||||||
|
if payload["audio_url"]:
|
||||||
|
if isinstance(payload["audio_url"], list):
|
||||||
|
for audio in payload["audio_url"]:
|
||||||
path = os.path.join(self.imgs_dir, audio)
|
path = os.path.join(self.imgs_dir, audio)
|
||||||
abm.message.append(Record(file=path, path=path))
|
abm.message.append(Record(file=path, path=path))
|
||||||
else:
|
else:
|
||||||
path = os.path.join(self.imgs_dir, payload['audio_url'])
|
path = os.path.join(self.imgs_dir, payload["audio_url"])
|
||||||
abm.message.append(Record(file=path, path=path))
|
abm.message.append(Record(file=path, path=path))
|
||||||
|
|
||||||
logger.debug(f"WebChatAdapter: {abm.message}")
|
logger.debug(f"WebChatAdapter: {abm.message}")
|
||||||
|
|
||||||
message_str = payload['message']
|
message_str = payload["message"]
|
||||||
abm.timestamp = int(time.time())
|
abm.timestamp = int(time.time())
|
||||||
abm.message_str = message_str
|
abm.message_str = message_str
|
||||||
abm.raw_message = data
|
abm.raw_message = data
|
||||||
@@ -101,12 +110,15 @@ class WebChatAdapter(Platform):
|
|||||||
return self.metadata
|
return self.metadata
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
|
||||||
message_event = WebChatMessageEvent(
|
message_event = WebChatMessageEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
message_obj=message,
|
message_obj=message,
|
||||||
platform_meta=self.meta(),
|
platform_meta=self.meta(),
|
||||||
session_id=message.session_id
|
session_id=message.session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
# Do nothing
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,30 +1,74 @@
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
import base64
|
||||||
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core import web_chat_back_queue
|
from astrbot.core import web_chat_back_queue
|
||||||
|
|
||||||
|
imgs_dir = "data/webchat/imgs"
|
||||||
|
|
||||||
|
|
||||||
class WebChatMessageEvent(AstrMessageEvent):
|
class WebChatMessageEvent(AstrMessageEvent):
|
||||||
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.imgs_dir = "data/webchat/imgs"
|
os.makedirs(imgs_dir, exist_ok=True)
|
||||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
@staticmethod
|
||||||
|
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||||
if not message:
|
if not message:
|
||||||
web_chat_back_queue.put_nowait(None)
|
await web_chat_back_queue.put(
|
||||||
|
{"type": "end", "data": "", "streaming": False}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
cid = self.session_id.split("!")[-1]
|
cid = session_id.split("!")[-1]
|
||||||
|
data = ""
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
data = comp.text
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "plain",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
# save image to local
|
# save image to local
|
||||||
filename = str(uuid.uuid4()) + ".jpg"
|
filename = str(uuid.uuid4()) + ".jpg"
|
||||||
path = os.path.join(self.imgs_dir, filename)
|
path = os.path.join(imgs_dir, filename)
|
||||||
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
|
ph = comp.file[8:]
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
with open(ph, "rb") as f2:
|
||||||
|
f.write(f2.read())
|
||||||
|
elif comp.file.startswith("base64://"):
|
||||||
|
base64_str = comp.file[9:]
|
||||||
|
image_data = base64.b64decode(base64_str)
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
elif comp.file and comp.file.startswith("http"):
|
||||||
|
await download_image_by_url(comp.file, path=path)
|
||||||
|
else:
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
with open(comp.file, "rb") as f2:
|
||||||
|
f.write(f2.read())
|
||||||
|
data = f"[IMAGE]{filename}"
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Record):
|
||||||
|
# save record to local
|
||||||
|
filename = str(uuid.uuid4()) + ".wav"
|
||||||
|
path = os.path.join(imgs_dir, filename)
|
||||||
if comp.file and comp.file.startswith("file:///"):
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
ph = comp.file[8:]
|
ph = comp.file[8:]
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
@@ -36,6 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
with open(comp.file, "rb") as f2:
|
with open(comp.file, "rb") as f2:
|
||||||
f.write(f2.read())
|
f.write(f2.read())
|
||||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
data = f"[RECORD]{filename}"
|
||||||
web_chat_back_queue.put_nowait(None)
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "record",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(f"webchat 忽略: {comp.type}")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": "",
|
||||||
|
"streaming": False,
|
||||||
|
"cid": self.session_id.split("!")[-1],
|
||||||
|
}
|
||||||
|
)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
final_data = ""
|
||||||
|
async for chain in generator:
|
||||||
|
final_data += await WebChatMessageEvent._send(
|
||||||
|
chain, session_id=self.session_id, streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": final_data,
|
||||||
|
"streaming": True,
|
||||||
|
"cid": self.session_id.split("!")[-1],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await super().send_streaming(generator)
|
||||||
|
|||||||
241
astrbot/core/platform/sources/wecom/wecom_adapter.py
Normal file
241
astrbot/core/platform/sources/wecom/wecom_adapter.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
import quart
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
PlatformMetadata,
|
||||||
|
MessageType,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from astrbot.api.platform import register_platform_adapter
|
||||||
|
from astrbot.core import logger
|
||||||
|
from requests import Response
|
||||||
|
|
||||||
|
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||||
|
from wechatpy.enterprise import WeChatClient
|
||||||
|
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
||||||
|
from wechatpy.exceptions import InvalidSignatureException
|
||||||
|
from wechatpy.enterprise import parse_message
|
||||||
|
from .wecom_event import WecomPlatformEvent
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12):
|
||||||
|
from typing import override
|
||||||
|
else:
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class WecomServer:
|
||||||
|
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||||
|
self.server = quart.Quart(__name__)
|
||||||
|
self.port = int(config.get("port"))
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
|
self.server.add_url_rule(
|
||||||
|
"/callback/command", view_func=self.verify, methods=["GET"]
|
||||||
|
)
|
||||||
|
self.server.add_url_rule(
|
||||||
|
"/callback/command", view_func=self.callback_command, methods=["POST"]
|
||||||
|
)
|
||||||
|
self.event_queue = event_queue
|
||||||
|
|
||||||
|
self.crypto = WeChatCrypto(
|
||||||
|
config["token"].strip(),
|
||||||
|
config["encoding_aes_key"].strip(),
|
||||||
|
config["corpid"].strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.callback = None
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def verify(self):
|
||||||
|
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||||
|
args = quart.request.args
|
||||||
|
try:
|
||||||
|
echo_str = self.crypto.check_signature(
|
||||||
|
args.get("msg_signature"),
|
||||||
|
args.get("timestamp"),
|
||||||
|
args.get("nonce"),
|
||||||
|
args.get("echostr"),
|
||||||
|
)
|
||||||
|
logger.info("验证请求有效性成功。")
|
||||||
|
return echo_str
|
||||||
|
except InvalidSignatureException:
|
||||||
|
logger.error("验证请求有效性失败,签名异常,请检查配置。")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def callback_command(self):
|
||||||
|
data = await quart.request.get_data()
|
||||||
|
msg_signature = quart.request.args.get("msg_signature")
|
||||||
|
timestamp = quart.request.args.get("timestamp")
|
||||||
|
nonce = quart.request.args.get("nonce")
|
||||||
|
try:
|
||||||
|
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||||
|
except InvalidSignatureException:
|
||||||
|
logger.error("解密失败,签名异常,请检查配置。")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
msg = parse_message(xml)
|
||||||
|
logger.info(f"解析成功: {msg}")
|
||||||
|
|
||||||
|
if self.callback:
|
||||||
|
await self.callback(msg)
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
async def start_polling(self):
|
||||||
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。"
|
||||||
|
)
|
||||||
|
await self.server.run_task(
|
||||||
|
host=self.callback_server_host,
|
||||||
|
port=self.port,
|
||||||
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown_trigger(self):
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("wecom", "wecom 适配器")
|
||||||
|
class WecomPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self.config = platform_config
|
||||||
|
self.settingss = platform_settings
|
||||||
|
self.client_self_id = uuid.uuid4().hex[:8]
|
||||||
|
self.api_base_url = platform_config.get(
|
||||||
|
"api_base_url", "https://qyapi.weixin.qq.com/cgi-bin/"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.api_base_url:
|
||||||
|
self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/"
|
||||||
|
|
||||||
|
if self.api_base_url.endswith("/"):
|
||||||
|
self.api_base_url = self.api_base_url[:-1]
|
||||||
|
if not self.api_base_url.endswith("/cgi-bin"):
|
||||||
|
self.api_base_url += "/cgi-bin"
|
||||||
|
|
||||||
|
if not self.api_base_url.endswith("/"):
|
||||||
|
self.api_base_url += "/"
|
||||||
|
|
||||||
|
self.server = WecomServer(self._event_queue, self.config)
|
||||||
|
|
||||||
|
self.client = WeChatClient(
|
||||||
|
self.config["corpid"].strip(),
|
||||||
|
self.config["secret"].strip(),
|
||||||
|
)
|
||||||
|
self.client.API_BASE_URL = self.api_base_url
|
||||||
|
|
||||||
|
async def callback(msg):
|
||||||
|
await self.convert_message(msg)
|
||||||
|
|
||||||
|
self.server.callback = callback
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
"wecom",
|
||||||
|
"wecom 适配器",
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def run(self):
|
||||||
|
await self.server.start_polling()
|
||||||
|
|
||||||
|
async def convert_message(self, msg):
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
if msg.type == "text":
|
||||||
|
assert isinstance(msg, TextMessage)
|
||||||
|
abm.message_str = msg.content
|
||||||
|
abm.self_id = str(msg.agent)
|
||||||
|
abm.message = [Plain(msg.content)]
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
msg.source,
|
||||||
|
msg.source,
|
||||||
|
)
|
||||||
|
abm.message_id = msg.id
|
||||||
|
abm.timestamp = msg.time
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
abm.raw_message = msg
|
||||||
|
elif msg.type == "image":
|
||||||
|
assert isinstance(msg, ImageMessage)
|
||||||
|
abm.message_str = "[图片]"
|
||||||
|
abm.self_id = str(msg.agent)
|
||||||
|
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
msg.source,
|
||||||
|
msg.source,
|
||||||
|
)
|
||||||
|
abm.message_id = msg.id
|
||||||
|
abm.timestamp = msg.time
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
abm.raw_message = msg
|
||||||
|
elif msg.type == "voice":
|
||||||
|
assert isinstance(msg, VoiceMessage)
|
||||||
|
|
||||||
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.client.media.download, msg.media_id
|
||||||
|
)
|
||||||
|
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(resp.content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||||
|
audio = AudioSegment.from_file(path)
|
||||||
|
audio.export(path_wav, format="wav")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。")
|
||||||
|
path_wav = path
|
||||||
|
return
|
||||||
|
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.self_id = str(msg.agent)
|
||||||
|
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
msg.source,
|
||||||
|
msg.source,
|
||||||
|
)
|
||||||
|
abm.message_id = msg.id
|
||||||
|
abm.timestamp = msg.time
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
abm.raw_message = msg
|
||||||
|
|
||||||
|
logger.info(f"abm: {abm}")
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
message_event = WecomPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.client,
|
||||||
|
)
|
||||||
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
def get_client(self) -> WeChatClient:
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.server.shutdown_event.set()
|
||||||
|
await self.server.server.shutdown()
|
||||||
|
logger.info("企业微信 适配器已被优雅地关闭")
|
||||||
99
astrbot/core/platform/sources/wecom/wecom_event.py
Normal file
99
astrbot/core/platform/sources/wecom/wecom_event.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import uuid
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
|
from wechatpy.enterprise import WeChatClient
|
||||||
|
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pydub
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WecomPlatformEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
client: WeChatClient,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_with_client(
|
||||||
|
client: WeChatClient, message: MessageChain, user_name: str
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
message_obj = self.message_obj
|
||||||
|
|
||||||
|
for comp in message.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
self.client.message.send_text(
|
||||||
|
message_obj.self_id, message_obj.session_id, comp.text
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Image):
|
||||||
|
img_path = await comp.convert_to_file_path()
|
||||||
|
|
||||||
|
with open(img_path, "rb") as f:
|
||||||
|
try:
|
||||||
|
response = self.client.media.upload("image", f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"企业微信上传图片失败: {e}")
|
||||||
|
await self.send(
|
||||||
|
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(f"企业微信上传图片返回: {response}")
|
||||||
|
self.client.message.send_image(
|
||||||
|
message_obj.self_id,
|
||||||
|
message_obj.session_id,
|
||||||
|
response["media_id"],
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Record):
|
||||||
|
record_path = await comp.convert_to_file_path()
|
||||||
|
# 转成amr
|
||||||
|
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||||
|
pydub.AudioSegment.from_wav(record_path).export(
|
||||||
|
record_path_amr, format="amr"
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(record_path_amr, "rb") as f:
|
||||||
|
try:
|
||||||
|
response = self.client.media.upload("voice", f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"企业微信上传语音失败: {e}")
|
||||||
|
await self.send(
|
||||||
|
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(f"企业微信上传语音返回: {response}")
|
||||||
|
self.client.message.send_voice(
|
||||||
|
message_obj.self_id,
|
||||||
|
message_obj.session_id,
|
||||||
|
response["media_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator)
|
||||||
@@ -1,10 +1,5 @@
|
|||||||
from .provider import Provider, Personality, STTProvider
|
from .provider import Provider, Personality, STTProvider
|
||||||
|
|
||||||
from .entites import ProviderMetaData
|
from .entities import ProviderMetaData
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
|
||||||
"Provider",
|
|
||||||
"Personality",
|
|
||||||
"ProviderMetaData",
|
|
||||||
"STTProvider"
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,54 +1,19 @@
|
|||||||
import enum
|
from astrbot.core.provider.entities import (
|
||||||
from dataclasses import dataclass, field
|
ProviderRequest,
|
||||||
from typing import List, Dict, Type
|
ProviderType,
|
||||||
from .func_tool_manager import FuncCall
|
ProviderMetaData,
|
||||||
|
ToolCallsResult,
|
||||||
|
AssistantMessageSegment,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
LLMResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
class ProviderType(enum.Enum):
|
"ProviderRequest",
|
||||||
CHAT_COMPLETION = "chat_completion"
|
"ProviderType",
|
||||||
SPEECH_TO_TEXT = "speech_to_text"
|
"ProviderMetaData",
|
||||||
TEXT_TO_SPEECH = "text_to_speech"
|
"ToolCallsResult",
|
||||||
|
"AssistantMessageSegment",
|
||||||
@dataclass
|
"ToolCallMessageSegment",
|
||||||
class ProviderMetaData():
|
"LLMResponse",
|
||||||
type: str
|
]
|
||||||
'''提供商适配器名称,如 openai, ollama'''
|
|
||||||
desc: str = ""
|
|
||||||
'''提供商适配器描述.'''
|
|
||||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
|
||||||
cls_type: Type = None
|
|
||||||
|
|
||||||
default_config_tmpl: dict = None
|
|
||||||
'''平台的默认配置模板'''
|
|
||||||
provider_display_name: str = None
|
|
||||||
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProviderRequest():
|
|
||||||
prompt: str
|
|
||||||
'''提示词'''
|
|
||||||
session_id: str = ""
|
|
||||||
'''会话 ID'''
|
|
||||||
image_urls: List[str] = None
|
|
||||||
'''图片 URL 列表'''
|
|
||||||
func_tool: FuncCall = None
|
|
||||||
'''工具'''
|
|
||||||
contexts: List = None
|
|
||||||
'''上下文。格式与 openai 的上下文格式一致:
|
|
||||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
|
||||||
'''
|
|
||||||
|
|
||||||
system_prompt: str = ""
|
|
||||||
'''系统提示词'''
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMResponse:
|
|
||||||
role: str
|
|
||||||
'''角色'''
|
|
||||||
completion_text: str = ""
|
|
||||||
'''LLM 返回的文本'''
|
|
||||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
|
||||||
'''工具调用参数'''
|
|
||||||
tools_call_name: List[str] = field(default_factory=list)
|
|
||||||
'''工具调用名称'''
|
|
||||||
|
|||||||
281
astrbot/core/provider/entities.py
Normal file
281
astrbot/core/provider/entities.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
import enum
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from astrbot import logger
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Type
|
||||||
|
from .func_tool_manager import FuncCall
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
)
|
||||||
|
from astrbot.core.db.po import Conversation
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(enum.Enum):
|
||||||
|
CHAT_COMPLETION = "chat_completion"
|
||||||
|
SPEECH_TO_TEXT = "speech_to_text"
|
||||||
|
TEXT_TO_SPEECH = "text_to_speech"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderMetaData:
|
||||||
|
type: str
|
||||||
|
"""提供商适配器名称,如 openai, ollama"""
|
||||||
|
desc: str = ""
|
||||||
|
"""提供商适配器描述."""
|
||||||
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||||
|
cls_type: Type = None
|
||||||
|
|
||||||
|
default_config_tmpl: dict = None
|
||||||
|
"""平台的默认配置模板"""
|
||||||
|
provider_display_name: str = None
|
||||||
|
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallMessageSegment:
|
||||||
|
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
|
tool_call_id: str
|
||||||
|
content: str
|
||||||
|
role: str = "tool"
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"tool_call_id": self.tool_call_id,
|
||||||
|
"content": self.content,
|
||||||
|
"role": self.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistantMessageSegment:
|
||||||
|
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
|
content: str = None
|
||||||
|
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||||
|
role: str = "assistant"
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
ret = {
|
||||||
|
"role": self.role,
|
||||||
|
}
|
||||||
|
if self.content:
|
||||||
|
ret["content"] = self.content
|
||||||
|
elif self.tool_calls:
|
||||||
|
ret["tool_calls"] = self.tool_calls
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallsResult:
|
||||||
|
"""工具调用结果"""
|
||||||
|
|
||||||
|
tool_calls_info: AssistantMessageSegment
|
||||||
|
"""函数调用的信息"""
|
||||||
|
tool_calls_result: List[ToolCallMessageSegment]
|
||||||
|
"""函数调用的结果"""
|
||||||
|
|
||||||
|
def to_openai_messages(self) -> List[Dict]:
|
||||||
|
ret = [
|
||||||
|
self.tool_calls_info.to_dict(),
|
||||||
|
*[item.to_dict() for item in self.tool_calls_result],
|
||||||
|
]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderRequest:
|
||||||
|
prompt: str
|
||||||
|
"""提示词"""
|
||||||
|
session_id: str = ""
|
||||||
|
"""会话 ID"""
|
||||||
|
image_urls: List[str] = None
|
||||||
|
"""图片 URL 列表"""
|
||||||
|
func_tool: FuncCall = None
|
||||||
|
"""可用的函数工具"""
|
||||||
|
contexts: List = None
|
||||||
|
"""上下文。格式与 openai 的上下文格式一致:
|
||||||
|
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||||
|
"""
|
||||||
|
system_prompt: str = ""
|
||||||
|
"""系统提示词"""
|
||||||
|
conversation: Conversation = None
|
||||||
|
|
||||||
|
tool_calls_result: ToolCallsResult = None
|
||||||
|
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def _print_friendly_context(self):
|
||||||
|
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||||
|
if not self.contexts:
|
||||||
|
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
|
||||||
|
|
||||||
|
result_parts = []
|
||||||
|
|
||||||
|
for ctx in self.contexts:
|
||||||
|
role = ctx.get("role", "unknown")
|
||||||
|
content = ctx.get("content", "")
|
||||||
|
|
||||||
|
if isinstance(content, str):
|
||||||
|
result_parts.append(f"{role}: {content}")
|
||||||
|
elif isinstance(content, list):
|
||||||
|
msg_parts = []
|
||||||
|
image_count = 0
|
||||||
|
|
||||||
|
for item in content:
|
||||||
|
item_type = item.get("type", "")
|
||||||
|
|
||||||
|
if item_type == "text":
|
||||||
|
msg_parts.append(item.get("text", ""))
|
||||||
|
elif item_type == "image_url":
|
||||||
|
image_count += 1
|
||||||
|
|
||||||
|
if image_count > 0:
|
||||||
|
if msg_parts:
|
||||||
|
msg_parts.append(f"[+{image_count} images]")
|
||||||
|
else:
|
||||||
|
msg_parts.append(f"[{image_count} images]")
|
||||||
|
|
||||||
|
result_parts.append(f"{role}: {''.join(msg_parts)}")
|
||||||
|
|
||||||
|
return result_parts
|
||||||
|
|
||||||
|
async def assemble_context(self) -> Dict:
|
||||||
|
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||||
|
if self.image_urls:
|
||||||
|
user_content = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": self.prompt}],
|
||||||
|
}
|
||||||
|
for image_url in self.image_urls:
|
||||||
|
if image_url.startswith("http"):
|
||||||
|
image_path = await download_image_by_url(image_url)
|
||||||
|
image_data = await self._encode_image_bs64(image_path)
|
||||||
|
elif image_url.startswith("file:///"):
|
||||||
|
image_path = image_url.replace("file:///", "")
|
||||||
|
image_data = await self._encode_image_bs64(image_path)
|
||||||
|
else:
|
||||||
|
image_data = await self._encode_image_bs64(image_url)
|
||||||
|
if not image_data:
|
||||||
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
|
continue
|
||||||
|
user_content["content"].append(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_data}}
|
||||||
|
)
|
||||||
|
return user_content
|
||||||
|
else:
|
||||||
|
return {"role": "user", "content": self.prompt}
|
||||||
|
|
||||||
|
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||||
|
"""将图片转换为 base64"""
|
||||||
|
if image_url.startswith("base64://"):
|
||||||
|
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||||
|
with open(image_url, "rb") as f:
|
||||||
|
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||||
|
return "data:image/jpeg;base64," + image_bs64
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
role: str
|
||||||
|
"""角色, assistant, tool, err"""
|
||||||
|
result_chain: MessageChain = None
|
||||||
|
"""返回的消息链"""
|
||||||
|
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||||
|
"""工具调用参数"""
|
||||||
|
tools_call_name: List[str] = field(default_factory=list)
|
||||||
|
"""工具调用名称"""
|
||||||
|
tools_call_ids: List[str] = field(default_factory=list)
|
||||||
|
"""工具调用 ID"""
|
||||||
|
|
||||||
|
raw_completion: ChatCompletion = None
|
||||||
|
_new_record: Dict[str, any] = None
|
||||||
|
|
||||||
|
_completion_text: str = ""
|
||||||
|
|
||||||
|
is_chunk: bool = False
|
||||||
|
"""是否是流式输出的单个 Chunk"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
role: str,
|
||||||
|
completion_text: str = "",
|
||||||
|
result_chain: MessageChain = None,
|
||||||
|
tools_call_args: List[Dict[str, any]] = None,
|
||||||
|
tools_call_name: List[str] = None,
|
||||||
|
tools_call_ids: List[str] = None,
|
||||||
|
raw_completion: ChatCompletion = None,
|
||||||
|
_new_record: Dict[str, any] = None,
|
||||||
|
is_chunk: bool = False,
|
||||||
|
):
|
||||||
|
"""初始化 LLMResponse
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role (str): 角色, assistant, tool, err
|
||||||
|
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
|
||||||
|
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
|
||||||
|
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
|
||||||
|
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
|
||||||
|
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||||
|
"""
|
||||||
|
if tools_call_args is None:
|
||||||
|
tools_call_args = []
|
||||||
|
if tools_call_name is None:
|
||||||
|
tools_call_name = []
|
||||||
|
if tools_call_ids is None:
|
||||||
|
tools_call_ids = []
|
||||||
|
|
||||||
|
self.role = role
|
||||||
|
self.completion_text = completion_text
|
||||||
|
self.result_chain = result_chain
|
||||||
|
self.tools_call_args = tools_call_args
|
||||||
|
self.tools_call_name = tools_call_name
|
||||||
|
self.tools_call_ids = tools_call_ids
|
||||||
|
self.raw_completion = raw_completion
|
||||||
|
self._new_record = _new_record
|
||||||
|
self.is_chunk = is_chunk
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completion_text(self):
|
||||||
|
if self.result_chain:
|
||||||
|
return self.result_chain.get_plain_text()
|
||||||
|
return self._completion_text
|
||||||
|
|
||||||
|
@completion_text.setter
|
||||||
|
def completion_text(self, value):
|
||||||
|
if self.result_chain:
|
||||||
|
self.result_chain.chain = [
|
||||||
|
comp
|
||||||
|
for comp in self.result_chain.chain
|
||||||
|
if not isinstance(comp, Comp.Plain)
|
||||||
|
] # 清空 Plain 组件
|
||||||
|
self.result_chain.chain.insert(0, Comp.Plain(value))
|
||||||
|
else:
|
||||||
|
self._completion_text = value
|
||||||
|
|
||||||
|
def to_openai_tool_calls(self) -> List[Dict]:
|
||||||
|
"""将工具调用信息转换为 OpenAI 格式"""
|
||||||
|
ret = []
|
||||||
|
for idx, tool_call_arg in enumerate(self.tools_call_args):
|
||||||
|
ret.append(
|
||||||
|
{
|
||||||
|
"id": self.tools_call_ids[idx],
|
||||||
|
"function": {
|
||||||
|
"name": self.tools_call_name[idx],
|
||||||
|
"arguments": json.dumps(tool_call_arg),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ret
|
||||||
@@ -1,7 +1,30 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Dict, List, Awaitable
|
import os
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from typing import Dict, List, Awaitable, Literal, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mcp
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||||
|
|
||||||
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||||
|
|
||||||
|
SUPPORTED_TYPES = [
|
||||||
|
"string",
|
||||||
|
"number",
|
||||||
|
"object",
|
||||||
|
"array",
|
||||||
|
"boolean",
|
||||||
|
] # json schema 支持的数据类型
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -13,24 +36,101 @@ class FuncTool:
|
|||||||
name: str
|
name: str
|
||||||
parameters: Dict
|
parameters: Dict
|
||||||
description: str
|
description: str
|
||||||
handler: Awaitable
|
handler: Awaitable = None
|
||||||
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||||
|
handler_module_path: str = None
|
||||||
|
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||||
|
|
||||||
|
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||||
|
"""
|
||||||
active: bool = True
|
active: bool = True
|
||||||
'''是否激活'''
|
"""是否激活"""
|
||||||
|
|
||||||
SUPPORTED_TYPES = [
|
origin: Literal["local", "mcp"] = "local"
|
||||||
"string",
|
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||||
"number",
|
|
||||||
"object",
|
# MCP 相关字段
|
||||||
"array",
|
mcp_server_name: str = None
|
||||||
"boolean",
|
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||||
] # json schema 支持的数据类型
|
mcp_client: MCPClient = None
|
||||||
|
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||||
|
|
||||||
|
async def execute(self, **args) -> Any:
|
||||||
|
"""执行函数调用"""
|
||||||
|
if self.origin == "local":
|
||||||
|
if not self.handler:
|
||||||
|
raise Exception(f"Local function {self.name} has no handler")
|
||||||
|
return await self.handler(**args)
|
||||||
|
elif self.origin == "mcp":
|
||||||
|
if not self.mcp_client or not self.mcp_client.session:
|
||||||
|
raise Exception(f"MCP client for {self.name} is not available")
|
||||||
|
# 使用name属性而不是额外的mcp_tool_name
|
||||||
|
if ":" in self.name:
|
||||||
|
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||||
|
actual_tool_name = self.name.split(":")[-1]
|
||||||
|
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||||
|
else:
|
||||||
|
return await self.mcp_client.session.call_tool(self.name, args)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown function origin: {self.origin}")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
def __init__(self):
|
||||||
|
# Initialize session and client objects
|
||||||
|
self.session: Optional[mcp.ClientSession] = None
|
||||||
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
self.name = None
|
||||||
|
self.active: bool = True
|
||||||
|
self.tools: List[mcp.Tool] = []
|
||||||
|
|
||||||
|
async def connect_to_server(self, mcp_server_config: dict):
|
||||||
|
"""Connect to an MCP server
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||||
|
"""
|
||||||
|
cfg = mcp_server_config.copy()
|
||||||
|
cfg.pop("active", None)
|
||||||
|
server_params = mcp.StdioServerParameters(
|
||||||
|
**cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
stdio_transport = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.stdio_client(server_params)
|
||||||
|
)
|
||||||
|
self.stdio, self.write = stdio_transport
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.ClientSession(self.stdio, self.write)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.session.initialize()
|
||||||
|
|
||||||
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
|
"""List all tools from the server and save them to self.tools"""
|
||||||
|
response = await self.session.list_tools()
|
||||||
|
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||||
|
self.tools = response.tools
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""Clean up resources"""
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
|
||||||
|
|
||||||
class FuncCall:
|
class FuncCall:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.func_list: List[FuncTool] = []
|
self.func_list: List[FuncTool] = []
|
||||||
|
"""内部加载的 func tools"""
|
||||||
|
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||||
|
"""MCP 服务列表"""
|
||||||
|
self.mcp_service_queue = asyncio.Queue()
|
||||||
|
"""用于外部控制 MCP 服务的启停"""
|
||||||
|
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
return len(self.func_list) == 0
|
return len(self.func_list) == 0
|
||||||
@@ -42,14 +142,16 @@ class FuncCall:
|
|||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Awaitable,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""添加函数调用工具
|
||||||
为函数调用(function-calling / tools-use)添加工具。
|
|
||||||
|
|
||||||
@param name: 函数名
|
@param name: 函数名
|
||||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||||
@param desc: 函数描述
|
@param desc: 函数描述
|
||||||
@param func_obj: 处理函数
|
@param func_obj: 处理函数
|
||||||
"""
|
"""
|
||||||
|
# check if the tool has been added before
|
||||||
|
self.remove_func(name)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
"properties": {},
|
"properties": {},
|
||||||
@@ -66,13 +168,14 @@ class FuncCall:
|
|||||||
handler=handler,
|
handler=handler,
|
||||||
)
|
)
|
||||||
self.func_list.append(_func)
|
self.func_list.append(_func)
|
||||||
|
logger.info(f"添加函数调用工具: {name}")
|
||||||
|
|
||||||
def remove_func(self, name: str) -> None:
|
def remove_func(self, name: str) -> None:
|
||||||
"""
|
"""
|
||||||
删除一个函数调用工具。
|
删除一个函数调用工具。
|
||||||
"""
|
"""
|
||||||
for i, f in enumerate(self.func_list):
|
for i, f in enumerate(self.func_list):
|
||||||
if f["name"] == name:
|
if f.name == name:
|
||||||
self.func_list.pop(i)
|
self.func_list.pop(i)
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -82,43 +185,233 @@ class FuncCall:
|
|||||||
return f
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_func_desc_openai_style(self) -> list:
|
async def _init_mcp_clients(self) -> None:
|
||||||
|
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"weather": {
|
||||||
|
"command": "uv",
|
||||||
|
"args": [
|
||||||
|
"--directory",
|
||||||
|
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
|
||||||
|
"run",
|
||||||
|
"weather.py"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
|
||||||
|
|
||||||
|
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||||
|
if not os.path.exists(mcp_json_file):
|
||||||
|
# 配置文件不存在错误处理
|
||||||
|
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||||
|
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
mcp_server_json_obj: Dict[str, Dict] = json.load(
|
||||||
|
open(mcp_json_file, "r", encoding="utf-8")
|
||||||
|
)["mcpServers"]
|
||||||
|
|
||||||
|
for name in mcp_server_json_obj.keys():
|
||||||
|
cfg = mcp_server_json_obj[name]
|
||||||
|
if cfg.get("active", True):
|
||||||
|
event = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(name, cfg, event)
|
||||||
|
)
|
||||||
|
self.mcp_client_event[name] = event
|
||||||
|
|
||||||
|
async def mcp_service_selector(self):
|
||||||
|
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||||
|
|
||||||
|
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||||
|
|
||||||
|
{"type": "init"} 初始化所有MCP客户端
|
||||||
|
|
||||||
|
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||||
|
|
||||||
|
{"type": "terminate"} 终止所有MCP客户端
|
||||||
|
|
||||||
|
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
data = await self.mcp_service_queue.get()
|
||||||
|
if data["type"] == "init":
|
||||||
|
if "name" in data:
|
||||||
|
event = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(
|
||||||
|
data["name"], data["cfg"], event
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.mcp_client_event[data["name"]] = event
|
||||||
|
else:
|
||||||
|
await self._init_mcp_clients()
|
||||||
|
elif data["type"] == "terminate":
|
||||||
|
if "name" in data:
|
||||||
|
# await self._terminate_mcp_client(data["name"])
|
||||||
|
if data["name"] in self.mcp_client_event:
|
||||||
|
self.mcp_client_event[data["name"]].set()
|
||||||
|
self.mcp_client_event.pop(data["name"], None)
|
||||||
|
else:
|
||||||
|
for name in self.mcp_client_dict.keys():
|
||||||
|
# await self._terminate_mcp_client(name)
|
||||||
|
# self.mcp_client_event[name].set()
|
||||||
|
if name in self.mcp_client_event:
|
||||||
|
self.mcp_client_event[name].set()
|
||||||
|
self.mcp_client_event.pop(name, None)
|
||||||
|
|
||||||
|
async def _init_mcp_client_task_wrapper(
|
||||||
|
self, name: str, cfg: dict, event: asyncio.Event
|
||||||
|
) -> None:
|
||||||
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
|
try:
|
||||||
|
await self._init_mcp_client(name, cfg)
|
||||||
|
await event.wait()
|
||||||
|
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||||
|
|
||||||
|
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||||
|
"""初始化单个MCP客户端"""
|
||||||
|
try:
|
||||||
|
# 先清理之前的客户端,如果存在
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
|
||||||
|
mcp_client = MCPClient()
|
||||||
|
mcp_client.name = name
|
||||||
|
await mcp_client.connect_to_server(config)
|
||||||
|
tools_res = await mcp_client.list_tools_and_save()
|
||||||
|
tool_names = [tool.name for tool in tools_res.tools]
|
||||||
|
self.mcp_client_dict[name] = mcp_client
|
||||||
|
|
||||||
|
# 移除该MCP服务之前的工具(如有)
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||||
|
for tool in mcp_client.tools:
|
||||||
|
func_tool = FuncTool(
|
||||||
|
name=tool.name,
|
||||||
|
parameters=tool.inputSchema,
|
||||||
|
description=tool.description,
|
||||||
|
origin="mcp",
|
||||||
|
mcp_server_name=name,
|
||||||
|
mcp_client=mcp_client,
|
||||||
|
)
|
||||||
|
self.func_list.append(func_tool)
|
||||||
|
|
||||||
|
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||||
|
# 发生错误时确保客户端被清理
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _terminate_mcp_client(self, name: str) -> None:
|
||||||
|
"""关闭并清理MCP客户端"""
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
try:
|
||||||
|
# 关闭MCP连接
|
||||||
|
await self.mcp_client_dict[name].cleanup()
|
||||||
|
del self.mcp_client_dict[name]
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||||
|
# 移除关联的FuncTool
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
logger.info(f"已关闭 MCP 服务 {name}")
|
||||||
|
|
||||||
|
def get_func_desc_openai_style(self, omit_empty_parameter_field = True) -> list:
|
||||||
"""
|
"""
|
||||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||||
"""
|
"""
|
||||||
_l = []
|
_l = []
|
||||||
|
# 处理所有工具(包括本地和MCP工具)
|
||||||
for f in self.func_list:
|
for f in self.func_list:
|
||||||
if not f.active:
|
if not f.active:
|
||||||
continue
|
continue
|
||||||
_l.append(
|
func_ = {
|
||||||
{
|
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": f.name,
|
"name": f.name,
|
||||||
"parameters": f.parameters,
|
# "parameters": f.parameters,
|
||||||
"description": f.description,
|
"description": f.description,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
func_["function"]["parameters"] = f.parameters
|
||||||
|
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||||
|
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||||
|
del func_["function"]["parameters"]
|
||||||
|
_l.append(func_)
|
||||||
return _l
|
return _l
|
||||||
|
|
||||||
|
def get_func_desc_anthropic_style(self) -> list:
|
||||||
|
"""
|
||||||
|
获得 Anthropic API 风格的**已经激活**的工具描述
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
for f in self.func_list:
|
||||||
|
if not f.active:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert internal format to Anthropic style
|
||||||
|
tool = {
|
||||||
|
"name": f.name,
|
||||||
|
"description": f.description,
|
||||||
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": f.parameters.get("properties", {}),
|
||||||
|
# Keep the required field from the original parameters if it exists
|
||||||
|
"required": f.parameters.get("required", []),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tools.append(tool)
|
||||||
|
return tools
|
||||||
|
|
||||||
def get_func_desc_google_genai_style(self) -> Dict:
|
def get_func_desc_google_genai_style(self) -> Dict:
|
||||||
declarations = {}
|
declarations = {}
|
||||||
tools = []
|
tools = []
|
||||||
for f in self.func_list:
|
for f in self.func_list:
|
||||||
if not f.active:
|
if not f.active:
|
||||||
continue
|
continue
|
||||||
tools.append(
|
|
||||||
{
|
func_declaration = {"name": f.name, "description": f.description}
|
||||||
"name": f.name,
|
|
||||||
"parameters": f.parameters,
|
# 检查并添加非空的properties参数
|
||||||
"description": f.description,
|
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||||
}
|
params = copy.deepcopy(params)
|
||||||
)
|
if params.get("properties", {}):
|
||||||
|
properties = params["properties"]
|
||||||
|
for key, value in properties.items():
|
||||||
|
if "default" in value:
|
||||||
|
del value["default"]
|
||||||
|
params["properties"] = properties
|
||||||
|
func_declaration["parameters"] = params
|
||||||
|
|
||||||
|
tools.append(func_declaration)
|
||||||
|
|
||||||
|
if tools:
|
||||||
declarations["function_declarations"] = tools
|
declarations["function_declarations"] = tools
|
||||||
return declarations
|
return declarations
|
||||||
|
|
||||||
|
|
||||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||||
_l = []
|
_l = []
|
||||||
for f in self.func_list:
|
for f in self.func_list:
|
||||||
@@ -126,9 +419,9 @@ class FuncCall:
|
|||||||
continue
|
continue
|
||||||
_l.append(
|
_l.append(
|
||||||
{
|
{
|
||||||
"name": f["name"],
|
"name": f.name,
|
||||||
"parameters": f["parameters"],
|
"parameters": f.parameters,
|
||||||
"description": f["description"],
|
"description": f.description,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||||
@@ -178,21 +471,22 @@ class FuncCall:
|
|||||||
func_name = tool["name"]
|
func_name = tool["name"]
|
||||||
args = tool["args"]
|
args = tool["args"]
|
||||||
# 调用函数
|
# 调用函数
|
||||||
tool_callable = None
|
func_tool = self.get_func(func_name)
|
||||||
for func in self.func_list:
|
if not func_tool:
|
||||||
if func.name == func_name:
|
|
||||||
tool_callable = func.star_handler_metadata.handler
|
|
||||||
break
|
|
||||||
if not tool_callable:
|
|
||||||
raise Exception(f"Request function {func_name} not found.")
|
raise Exception(f"Request function {func_name} not found.")
|
||||||
ret = await tool_callable(**args)
|
|
||||||
|
ret = await func_tool.execute(**args)
|
||||||
if ret:
|
if ret:
|
||||||
tool_call_result.append(str(ret))
|
tool_call_result.append(str(ret))
|
||||||
return tool_call_result, True
|
return tool_call_result, True
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.func_list)
|
return str(self.func_list)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.func_list)
|
return str(self.func_list)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
for name in self.mcp_client_dict.keys():
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
||||||
|
|||||||
@@ -1,21 +1,35 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
import asyncio
|
||||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||||
from .entites import ProviderType
|
from .entities import ProviderType
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from collections import defaultdict
|
|
||||||
from .register import provider_cls_map, llm_tools
|
from .register import provider_cls_map, llm_tools
|
||||||
from astrbot.core import logger, sp
|
from astrbot.core import logger, sp
|
||||||
|
|
||||||
class ProviderManager():
|
|
||||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
|
||||||
self.providers_config: List = config['provider']
|
|
||||||
self.provider_settings: dict = config['provider_settings']
|
|
||||||
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
|
|
||||||
self.persona_configs: list = config.get('persona', [])
|
|
||||||
|
|
||||||
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
|
class ProviderManager:
|
||||||
|
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||||
|
self.providers_config: List = config["provider"]
|
||||||
|
self.provider_settings: dict = config["provider_settings"]
|
||||||
|
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||||
|
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||||
|
self.persona_configs: list = config.get("persona", [])
|
||||||
|
self.astrbot_config = config
|
||||||
|
|
||||||
|
self.selected_provider_id = sp.get("curr_provider")
|
||||||
|
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||||
|
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||||
|
self.provider_enabled = self.provider_settings.get("enable", False)
|
||||||
|
self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||||
|
self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||||
|
|
||||||
|
# 人格情景管理
|
||||||
|
# 目前没有拆成独立的模块
|
||||||
|
self.default_persona_name = self.provider_settings.get(
|
||||||
|
"default_personality", "default"
|
||||||
|
)
|
||||||
self.personas: List[Personality] = []
|
self.personas: List[Personality] = []
|
||||||
self.selected_default_persona = None
|
self.selected_default_persona = None
|
||||||
for persona in self.persona_configs:
|
for persona in self.persona_configs:
|
||||||
@@ -25,55 +39,74 @@ class ProviderManager():
|
|||||||
mid_processed = ""
|
mid_processed = ""
|
||||||
if begin_dialogs:
|
if begin_dialogs:
|
||||||
if len(begin_dialogs) % 2 != 0:
|
if len(begin_dialogs) % 2 != 0:
|
||||||
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
|
logger.error(
|
||||||
continue
|
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||||
|
)
|
||||||
|
begin_dialogs = []
|
||||||
user_turn = True
|
user_turn = True
|
||||||
for dialog in begin_dialogs:
|
for dialog in begin_dialogs:
|
||||||
bd_processed.append({
|
bd_processed.append(
|
||||||
|
{
|
||||||
"role": "user" if user_turn else "assistant",
|
"role": "user" if user_turn else "assistant",
|
||||||
"content": dialog,
|
"content": dialog,
|
||||||
"_no_save": None # 不持久化到 db
|
"_no_save": None, # 不持久化到 db
|
||||||
})
|
}
|
||||||
|
)
|
||||||
user_turn = not user_turn
|
user_turn = not user_turn
|
||||||
if mood_imitation_dialogs:
|
if mood_imitation_dialogs:
|
||||||
if len(mood_imitation_dialogs) % 2 != 0:
|
if len(mood_imitation_dialogs) % 2 != 0:
|
||||||
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
|
logger.error(
|
||||||
continue
|
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
|
||||||
|
)
|
||||||
|
mood_imitation_dialogs = []
|
||||||
user_turn = True
|
user_turn = True
|
||||||
for dialog in begin_dialogs:
|
for dialog in mood_imitation_dialogs:
|
||||||
role = "A" if user_turn else "B"
|
role = "A" if user_turn else "B"
|
||||||
mid_processed += f"{role}: {dialog}\n"
|
mid_processed += f"{role}: {dialog}\n"
|
||||||
if not user_turn:
|
if not user_turn:
|
||||||
mid_processed += '\n'
|
mid_processed += "\n"
|
||||||
user_turn = not user_turn
|
user_turn = not user_turn
|
||||||
|
|
||||||
try:
|
try:
|
||||||
persona = Personality(
|
persona = Personality(
|
||||||
**persona,
|
**persona,
|
||||||
_begin_dialogs_processed=bd_processed,
|
_begin_dialogs_processed=bd_processed,
|
||||||
_mood_imitation_dialogs_processed=mid_processed
|
_mood_imitation_dialogs_processed=mid_processed,
|
||||||
)
|
)
|
||||||
if persona['name'] == self.default_persona_name:
|
if persona["name"] == self.default_persona_name:
|
||||||
self.selected_default_persona = persona
|
self.selected_default_persona = persona
|
||||||
self.personas.append(persona)
|
self.personas.append(persona)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解析 Persona 配置失败:{e}")
|
logger.error(f"解析 Persona 配置失败:{e}")
|
||||||
|
|
||||||
|
if not self.selected_default_persona and len(self.personas) > 0:
|
||||||
|
# 默认选择第一个
|
||||||
|
self.selected_default_persona = self.personas[0]
|
||||||
|
|
||||||
|
if not self.selected_default_persona:
|
||||||
|
self.selected_default_persona = Personality(
|
||||||
|
prompt="You are a helpful and friendly assistant.",
|
||||||
|
name="default",
|
||||||
|
_begin_dialogs_processed=[],
|
||||||
|
_mood_imitation_dialogs_processed="",
|
||||||
|
)
|
||||||
|
self.personas.append(self.selected_default_persona)
|
||||||
|
|
||||||
self.provider_insts: List[Provider] = []
|
self.provider_insts: List[Provider] = []
|
||||||
'''加载的 Provider 的实例'''
|
"""加载的 Provider 的实例"""
|
||||||
self.stt_provider_insts: List[STTProvider] = []
|
self.stt_provider_insts: List[STTProvider] = []
|
||||||
'''加载的 Speech To Text Provider 的实例'''
|
"""加载的 Speech To Text Provider 的实例"""
|
||||||
self.tts_provider_insts: List[TTSProvider] = []
|
self.tts_provider_insts: List[TTSProvider] = []
|
||||||
'''加载的 Text To Speech Provider 的实例'''
|
"""加载的 Text To Speech Provider 的实例"""
|
||||||
|
self.inst_map = {}
|
||||||
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||||
self.llm_tools = llm_tools
|
self.llm_tools = llm_tools
|
||||||
self.curr_provider_inst: Provider = None
|
self.curr_provider_inst: Provider = None
|
||||||
'''当前使用的 Provider 实例'''
|
"""当前使用的 Provider 实例"""
|
||||||
self.curr_stt_provider_inst: STTProvider = None
|
self.curr_stt_provider_inst: STTProvider = None
|
||||||
'''当前使用的 Speech To Text Provider 实例'''
|
"""当前使用的 Speech To Text Provider 实例"""
|
||||||
self.curr_tts_provider_inst: TTSProvider = None
|
self.curr_tts_provider_inst: TTSProvider = None
|
||||||
'''当前使用的 Text To Speech Provider 实例'''
|
"""当前使用的 Text To Speech Provider 实例"""
|
||||||
self.loaded_ids = defaultdict(bool)
|
|
||||||
self.db_helper = db_helper
|
self.db_helper = db_helper
|
||||||
|
|
||||||
# kdb(experimental)
|
# kdb(experimental)
|
||||||
@@ -82,82 +115,155 @@ class ProviderManager():
|
|||||||
if kdb_cfg and len(kdb_cfg):
|
if kdb_cfg and len(kdb_cfg):
|
||||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||||
|
|
||||||
for provider_cfg in self.providers_config:
|
|
||||||
if not provider_cfg['enable']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if provider_cfg['id'] in self.loaded_ids:
|
|
||||||
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
|
|
||||||
self.loaded_ids[provider_cfg['id']] = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
match provider_cfg['type']:
|
|
||||||
case "openai_chat_completion":
|
|
||||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
|
||||||
case "zhipu_chat_completion":
|
|
||||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
|
||||||
case "llm_tuner":
|
|
||||||
logger.info("加载 LLM Tuner 工具 ...")
|
|
||||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
|
||||||
case "dify":
|
|
||||||
from .sources.dify_source import ProviderDify # noqa: F401
|
|
||||||
case "googlegenai_chat_completion":
|
|
||||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
|
||||||
case "openai_whisper_api":
|
|
||||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
|
||||||
case "openai_whisper_selfhost":
|
|
||||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
|
||||||
case "openai_tts_api":
|
|
||||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
|
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
|
||||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
|
|
||||||
continue
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
for provider_config in self.providers_config:
|
for provider_config in self.providers_config:
|
||||||
if not provider_config['enable']:
|
await self.load_provider(provider_config)
|
||||||
continue
|
|
||||||
if provider_config['type'] not in provider_cls_map:
|
|
||||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
|
||||||
continue
|
|
||||||
selected_provider_id = sp.get("curr_provider")
|
|
||||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
|
||||||
selected_tts_provider_id = self.provider_settings.get("provider_id")
|
|
||||||
provider_enabled = self.provider_settings.get("enable", False)
|
|
||||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
|
||||||
tts_enabled = self.provider_settings.get("enable", False)
|
|
||||||
|
|
||||||
provider_metadata = provider_cls_map[provider_config['type']]
|
if not self.curr_provider_inst:
|
||||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||||
|
|
||||||
|
if self.stt_enabled and not self.curr_stt_provider_inst:
|
||||||
|
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||||
|
|
||||||
|
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||||
|
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||||
|
|
||||||
|
# 初始化 MCP Client 连接
|
||||||
|
asyncio.create_task(
|
||||||
|
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
||||||
|
)
|
||||||
|
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
||||||
|
|
||||||
|
async def load_provider(self, provider_config: dict):
|
||||||
|
if not provider_config["enable"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动态导入
|
||||||
|
try:
|
||||||
|
match provider_config["type"]:
|
||||||
|
case "openai_chat_completion":
|
||||||
|
from .sources.openai_source import (
|
||||||
|
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||||
|
)
|
||||||
|
case "zhipu_chat_completion":
|
||||||
|
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||||
|
case "anthropic_chat_completion":
|
||||||
|
from .sources.anthropic_source import (
|
||||||
|
ProviderAnthropic as ProviderAnthropic,
|
||||||
|
)
|
||||||
|
case "llm_tuner":
|
||||||
|
logger.info("加载 LLM Tuner 工具 ...")
|
||||||
|
from .sources.llmtuner_source import (
|
||||||
|
LLMTunerModelLoader as LLMTunerModelLoader,
|
||||||
|
)
|
||||||
|
case "dify":
|
||||||
|
from .sources.dify_source import ProviderDify as ProviderDify
|
||||||
|
case "dashscope":
|
||||||
|
from .sources.dashscope_source import (
|
||||||
|
ProviderDashscope as ProviderDashscope,
|
||||||
|
)
|
||||||
|
case "googlegenai_chat_completion":
|
||||||
|
from .sources.gemini_source import (
|
||||||
|
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||||
|
)
|
||||||
|
case "sensevoice_stt_selfhost":
|
||||||
|
from .sources.sensevoice_selfhosted_source import (
|
||||||
|
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||||
|
)
|
||||||
|
case "openai_whisper_api":
|
||||||
|
from .sources.whisper_api_source import (
|
||||||
|
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||||
|
)
|
||||||
|
case "openai_whisper_selfhost":
|
||||||
|
from .sources.whisper_selfhosted_source import (
|
||||||
|
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||||
|
)
|
||||||
|
case "openai_tts_api":
|
||||||
|
from .sources.openai_tts_api_source import (
|
||||||
|
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||||
|
)
|
||||||
|
case "edge_tts":
|
||||||
|
from .sources.edge_tts_source import (
|
||||||
|
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||||
|
)
|
||||||
|
case "gsvi_tts_api":
|
||||||
|
from .sources.gsvi_tts_source import (
|
||||||
|
ProviderGSVITTS as ProviderGSVITTS,
|
||||||
|
)
|
||||||
|
case "fishaudio_tts_api":
|
||||||
|
from .sources.fishaudio_tts_api_source import (
|
||||||
|
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||||
|
)
|
||||||
|
case "dashscope_tts":
|
||||||
|
from .sources.dashscope_tts import (
|
||||||
|
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||||
|
)
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
logger.critical(
|
||||||
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical(
|
||||||
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if provider_config["type"] not in provider_cls_map:
|
||||||
|
logger.error(
|
||||||
|
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
provider_metadata = provider_cls_map[provider_config["type"]]
|
||||||
try:
|
try:
|
||||||
# 按任务实例化提供商
|
# 按任务实例化提供商
|
||||||
|
|
||||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||||
# STT 任务
|
# STT 任务
|
||||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
inst = provider_metadata.cls_type(
|
||||||
|
provider_config, self.provider_settings
|
||||||
|
)
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
|
|
||||||
self.stt_provider_insts.append(inst)
|
self.stt_provider_insts.append(inst)
|
||||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
if (
|
||||||
|
self.selected_stt_provider_id == provider_config["id"]
|
||||||
|
and self.stt_enabled
|
||||||
|
):
|
||||||
|
self.curr_stt_provider_inst = inst
|
||||||
|
logger.info(
|
||||||
|
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
||||||
|
)
|
||||||
|
if not self.curr_stt_provider_inst and self.stt_enabled:
|
||||||
self.curr_stt_provider_inst = inst
|
self.curr_stt_provider_inst = inst
|
||||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
|
|
||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||||
# TTS 任务
|
# TTS 任务
|
||||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
inst = provider_metadata.cls_type(
|
||||||
|
provider_config, self.provider_settings
|
||||||
|
)
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
|
|
||||||
self.tts_provider_insts.append(inst)
|
self.tts_provider_insts.append(inst)
|
||||||
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
|
if (
|
||||||
|
self.selected_tts_provider_id == provider_config["id"]
|
||||||
|
and self.tts_enabled
|
||||||
|
):
|
||||||
|
self.curr_tts_provider_inst = inst
|
||||||
|
logger.info(
|
||||||
|
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
||||||
|
)
|
||||||
|
if not self.curr_tts_provider_inst and self.tts_enabled:
|
||||||
self.curr_tts_provider_inst = inst
|
self.curr_tts_provider_inst = inst
|
||||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
|
|
||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||||
# 文本生成任务
|
# 文本生成任务
|
||||||
@@ -165,45 +271,116 @@ class ProviderManager():
|
|||||||
provider_config,
|
provider_config,
|
||||||
self.provider_settings,
|
self.provider_settings,
|
||||||
self.db_helper,
|
self.db_helper,
|
||||||
self.provider_settings.get('persistant_history', True),
|
self.provider_settings.get("persistant_history", True),
|
||||||
self.selected_default_persona
|
self.selected_default_persona,
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
|
|
||||||
self.provider_insts.append(inst)
|
self.provider_insts.append(inst)
|
||||||
if selected_provider_id == provider_config['id'] and provider_enabled:
|
if (
|
||||||
|
self.selected_provider_id == provider_config["id"]
|
||||||
|
and self.provider_enabled
|
||||||
|
):
|
||||||
|
self.curr_provider_inst = inst
|
||||||
|
logger.info(
|
||||||
|
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
||||||
|
)
|
||||||
|
if not self.curr_provider_inst and self.provider_enabled:
|
||||||
self.curr_provider_inst = inst
|
self.curr_provider_inst = inst
|
||||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
|
|
||||||
|
|
||||||
|
self.inst_map[provider_config["id"]] = inst
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
logger.error(traceback.format_exc())
|
||||||
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
|
logger.error(
|
||||||
|
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
|
async def reload(self, provider_config: dict):
|
||||||
|
await self.terminate_provider(provider_config["id"])
|
||||||
|
if provider_config["enable"]:
|
||||||
|
await self.load_provider(provider_config)
|
||||||
|
|
||||||
|
# 和配置文件保持同步
|
||||||
|
config_ids = [provider["id"] for provider in self.providers_config]
|
||||||
|
for key in list(self.inst_map.keys()):
|
||||||
|
if key not in config_ids:
|
||||||
|
await self.terminate_provider(key)
|
||||||
|
|
||||||
|
if len(self.provider_insts) == 0:
|
||||||
|
self.curr_provider_inst = None
|
||||||
|
elif (
|
||||||
|
self.curr_provider_inst is None
|
||||||
|
and len(self.provider_insts) > 0
|
||||||
|
and self.provider_enabled
|
||||||
|
):
|
||||||
self.curr_provider_inst = self.provider_insts[0]
|
self.curr_provider_inst = self.provider_insts[0]
|
||||||
|
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||||
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
|
if len(self.stt_provider_insts) == 0:
|
||||||
|
self.curr_stt_provider_inst = None
|
||||||
|
elif (
|
||||||
|
self.curr_stt_provider_inst is None
|
||||||
|
and len(self.stt_provider_insts) > 0
|
||||||
|
and self.stt_enabled
|
||||||
|
):
|
||||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||||
|
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||||
|
logger.info(
|
||||||
|
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
|
if len(self.tts_provider_insts) == 0:
|
||||||
|
self.curr_tts_provider_inst = None
|
||||||
|
elif (
|
||||||
|
self.curr_tts_provider_inst is None
|
||||||
|
and len(self.tts_provider_insts) > 0
|
||||||
|
and self.tts_enabled
|
||||||
|
):
|
||||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||||
|
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||||
if not self.curr_provider_inst:
|
logger.info(
|
||||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||||
|
)
|
||||||
if stt_enabled and not self.curr_stt_provider_inst:
|
|
||||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
|
||||||
|
|
||||||
if tts_enabled and not self.curr_tts_provider_inst:
|
|
||||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
|
||||||
|
|
||||||
|
|
||||||
def get_insts(self):
|
def get_insts(self):
|
||||||
return self.provider_insts
|
return self.provider_insts
|
||||||
|
|
||||||
|
async def terminate_provider(self, provider_id: str):
|
||||||
|
if provider_id in self.inst_map:
|
||||||
|
logger.info(
|
||||||
|
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.inst_map[provider_id] in self.provider_insts:
|
||||||
|
self.provider_insts.remove(self.inst_map[provider_id])
|
||||||
|
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||||
|
self.stt_provider_insts.remove(self.inst_map[provider_id])
|
||||||
|
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||||
|
self.tts_provider_insts.remove(self.inst_map[provider_id])
|
||||||
|
|
||||||
|
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||||
|
self.curr_provider_inst = None
|
||||||
|
if self.inst_map[provider_id] == self.curr_stt_provider_inst:
|
||||||
|
self.curr_stt_provider_inst = None
|
||||||
|
if self.inst_map[provider_id] == self.curr_tts_provider_inst:
|
||||||
|
self.curr_tts_provider_inst = None
|
||||||
|
|
||||||
|
if getattr(self.inst_map[provider_id], "terminate", None):
|
||||||
|
await self.inst_map[provider_id].terminate()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
||||||
|
)
|
||||||
|
del self.inst_map[provider_id]
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
for provider_inst in self.provider_insts:
|
for provider_inst in self.provider_insts:
|
||||||
if hasattr(provider_inst, "terminate"):
|
if hasattr(provider_inst, "terminate"):
|
||||||
await provider_inst.terminate()
|
await provider_inst.terminate()
|
||||||
|
# 清理 MCP Client 连接
|
||||||
|
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
import abc
|
import abc
|
||||||
import json
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core import logger
|
from typing import TypedDict, AsyncGenerator
|
||||||
from typing import TypedDict
|
|
||||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from astrbot.core.provider.entites import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
class Personality(TypedDict):
|
class Personality(TypedDict):
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
name: str = ""
|
name: str = ""
|
||||||
@@ -15,12 +14,12 @@ class Personality(TypedDict):
|
|||||||
mood_imitation_dialogs: List[str] = []
|
mood_imitation_dialogs: List[str] = []
|
||||||
|
|
||||||
# cache
|
# cache
|
||||||
_begin_dialogs_processed: List[dict]
|
_begin_dialogs_processed: List[dict] = []
|
||||||
_mood_imitation_dialogs_processed: str
|
_mood_imitation_dialogs_processed: str = ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProviderMeta():
|
class ProviderMeta:
|
||||||
id: str
|
id: str
|
||||||
model: str
|
model: str
|
||||||
type: str
|
type: str
|
||||||
@@ -33,19 +32,19 @@ class AbstractProvider(abc.ABC):
|
|||||||
self.provider_config = provider_config
|
self.provider_config = provider_config
|
||||||
|
|
||||||
def set_model(self, model_name: str):
|
def set_model(self, model_name: str):
|
||||||
'''设置当前使用的模型名称'''
|
"""设置当前使用的模型名称"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def get_model(self) -> str:
|
def get_model(self) -> str:
|
||||||
'''获得当前使用的模型名称'''
|
"""获得当前使用的模型名称"""
|
||||||
return self.model_name
|
return self.model_name
|
||||||
|
|
||||||
def meta(self) -> ProviderMeta:
|
def meta(self) -> ProviderMeta:
|
||||||
'''获取 Provider 的元数据'''
|
"""获取 Provider 的元数据"""
|
||||||
return ProviderMeta(
|
return ProviderMeta(
|
||||||
id=self.provider_config['id'],
|
id=self.provider_config["id"],
|
||||||
model=self.get_model(),
|
model=self.get_model(),
|
||||||
type=self.provider_config['type']
|
type=self.provider_config["type"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,35 +55,21 @@ class Provider(AbstractProvider):
|
|||||||
provider_settings: dict,
|
provider_settings: dict,
|
||||||
persistant_history: bool = True,
|
persistant_history: bool = True,
|
||||||
db_helper: BaseDatabase = None,
|
db_helper: BaseDatabase = None,
|
||||||
default_persona: Personality = None
|
default_persona: Personality = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(provider_config)
|
super().__init__(provider_config)
|
||||||
|
|
||||||
self.session_memory = defaultdict(list)
|
|
||||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
|
||||||
|
|
||||||
self.provider_settings = provider_settings
|
self.provider_settings = provider_settings
|
||||||
|
|
||||||
self.curr_personality: Personality = default_persona
|
self.curr_personality: Personality = default_persona
|
||||||
'''维护了当前的使用的 persona,即人格。可能为 None'''
|
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
||||||
|
|
||||||
self.db_helper = db_helper
|
|
||||||
'''用于持久化的数据库操作对象。'''
|
|
||||||
|
|
||||||
if persistant_history:
|
|
||||||
# 读取历史记录
|
|
||||||
try:
|
|
||||||
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
|
|
||||||
self.session_memory[history.session_id] = json.loads(history.content)
|
|
||||||
except BaseException as e:
|
|
||||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_current_key(self) -> str:
|
def get_current_key(self) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_keys(self) -> List[str]:
|
def get_keys(self) -> List[str]:
|
||||||
'''获得提供商 Key'''
|
"""获得提供商 Key"""
|
||||||
return self.provider_config.get("key", [])
|
return self.provider_config.get("key", [])
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -93,58 +78,83 @@ class Provider(AbstractProvider):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_models(self) -> List[str]:
|
def get_models(self) -> List[str]:
|
||||||
'''获得支持的模型列表'''
|
"""获得支持的模型列表"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
|
async def text_chat(
|
||||||
'''获取人类可读的上下文
|
self,
|
||||||
|
|
||||||
page 从 1 开始
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
["User: 你好", "Assistant: 你好!"]
|
|
||||||
|
|
||||||
Return:
|
|
||||||
contexts: List[str]: 上下文列表
|
|
||||||
total_pages: int: 总页数
|
|
||||||
'''
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def text_chat(self,
|
|
||||||
prompt: str,
|
prompt: str,
|
||||||
session_id: str=None,
|
session_id: str = None,
|
||||||
image_urls: List[str]=None,
|
image_urls: List[str] = None,
|
||||||
func_tool: FuncCall=None,
|
func_tool: FuncCall = None,
|
||||||
contexts: List=None,
|
contexts: List = None,
|
||||||
system_prompt: str=None,
|
system_prompt: str = None,
|
||||||
**kwargs) -> LLMResponse:
|
tool_calls_result: ToolCallsResult = None,
|
||||||
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
**kwargs,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: 提示词
|
prompt: 提示词
|
||||||
session_id: 会话 ID
|
session_id: 会话 ID(此属性已经被废弃)
|
||||||
image_urls: 图片 URL 列表
|
image_urls: 图片 URL 列表
|
||||||
tools: Function-calling 工具
|
tools: Function-calling 工具
|
||||||
contexts: 上下文
|
contexts: 上下文
|
||||||
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||||
kwargs: 其他参数
|
kwargs: 其他参数
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
|
|
||||||
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
|
|
||||||
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
|
|
||||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||||
'''
|
"""
|
||||||
raise NotImplementedError()
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
async def text_chat_stream(
|
||||||
async def forget(self, session_id: str) -> bool:
|
self,
|
||||||
'''重置某一个 session_id 的上下文'''
|
prompt: str,
|
||||||
raise NotImplementedError()
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = None,
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts: List = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
tool_calls_result: ToolCallsResult = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
|
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示词
|
||||||
|
session_id: 会话 ID(此属性已经被废弃)
|
||||||
|
image_urls: 图片 URL 列表
|
||||||
|
tools: Function-calling 工具
|
||||||
|
contexts: 上下文
|
||||||
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||||
|
kwargs: 其他参数
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||||
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def pop_record(self, context: List):
|
||||||
|
"""
|
||||||
|
弹出 context 第一条非系统提示词对话记录
|
||||||
|
"""
|
||||||
|
poped = 0
|
||||||
|
indexs_to_pop = []
|
||||||
|
for idx, record in enumerate(context):
|
||||||
|
if record["role"] == "system":
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
indexs_to_pop.append(idx)
|
||||||
|
poped += 1
|
||||||
|
if poped == 2:
|
||||||
|
break
|
||||||
|
|
||||||
|
for idx in reversed(indexs_to_pop):
|
||||||
|
context.pop(idx)
|
||||||
|
|
||||||
|
|
||||||
class STTProvider(AbstractProvider):
|
class STTProvider(AbstractProvider):
|
||||||
@@ -155,7 +165,7 @@ class STTProvider(AbstractProvider):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_text(self, audio_url: str) -> str:
|
async def get_text(self, audio_url: str) -> str:
|
||||||
'''获取音频的文本'''
|
"""获取音频的文本"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@@ -167,5 +177,5 @@ class TTSProvider(AbstractProvider):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_audio(self, text: str) -> str:
|
async def get_audio(self, text: str) -> str:
|
||||||
'''获取文本的音频,返回音频文件路径'''
|
"""获取文本的音频,返回音频文件路径"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -1,35 +1,39 @@
|
|||||||
from typing import List, Dict, Type
|
from typing import List, Dict
|
||||||
from .entites import ProviderMetaData, ProviderType
|
from .entities import ProviderMetaData, ProviderType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from .func_tool_manager import FuncCall
|
from .func_tool_manager import FuncCall
|
||||||
|
|
||||||
provider_registry: List[ProviderMetaData] = []
|
provider_registry: List[ProviderMetaData] = []
|
||||||
'''维护了通过装饰器注册的 Provider'''
|
"""维护了通过装饰器注册的 Provider"""
|
||||||
provider_cls_map: Dict[str, ProviderMetaData] = {}
|
provider_cls_map: Dict[str, ProviderMetaData] = {}
|
||||||
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
|
"""维护了 Provider 类型名称和 ProviderMetadata 的映射"""
|
||||||
|
|
||||||
llm_tools = FuncCall()
|
llm_tools = FuncCall()
|
||||||
|
|
||||||
|
|
||||||
def register_provider_adapter(
|
def register_provider_adapter(
|
||||||
provider_type_name: str,
|
provider_type_name: str,
|
||||||
desc: str,
|
desc: str,
|
||||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
|
||||||
default_config_tmpl: dict = None,
|
default_config_tmpl: dict = None,
|
||||||
provider_display_name: str = None
|
provider_display_name: str = None,
|
||||||
):
|
):
|
||||||
'''用于注册平台适配器的带参装饰器'''
|
"""用于注册平台适配器的带参装饰器"""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
if provider_type_name in provider_cls_map:
|
if provider_type_name in provider_cls_map:
|
||||||
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
|
raise ValueError(
|
||||||
|
f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。"
|
||||||
|
)
|
||||||
|
|
||||||
# 添加必备选项
|
# 添加必备选项
|
||||||
if default_config_tmpl:
|
if default_config_tmpl:
|
||||||
if 'type' not in default_config_tmpl:
|
if "type" not in default_config_tmpl:
|
||||||
default_config_tmpl['type'] = provider_type_name
|
default_config_tmpl["type"] = provider_type_name
|
||||||
if 'enable' not in default_config_tmpl:
|
if "enable" not in default_config_tmpl:
|
||||||
default_config_tmpl['enable'] = False
|
default_config_tmpl["enable"] = False
|
||||||
if 'id' not in default_config_tmpl:
|
if "id" not in default_config_tmpl:
|
||||||
default_config_tmpl['id'] = provider_type_name
|
default_config_tmpl["id"] = provider_type_name
|
||||||
|
|
||||||
pm = ProviderMetaData(
|
pm = ProviderMetaData(
|
||||||
type=provider_type_name,
|
type=provider_type_name,
|
||||||
@@ -37,7 +41,7 @@ def register_provider_adapter(
|
|||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
cls_type=cls,
|
cls_type=cls,
|
||||||
default_config_tmpl=default_config_tmpl,
|
default_config_tmpl=default_config_tmpl,
|
||||||
provider_display_name=provider_display_name
|
provider_display_name=provider_display_name,
|
||||||
)
|
)
|
||||||
provider_registry.append(pm)
|
provider_registry.append(pm)
|
||||||
provider_cls_map[provider_type_name] = pm
|
provider_cls_map[provider_type_name] = pm
|
||||||
|
|||||||
232
astrbot/core/provider/sources/anthropic_source.py
Normal file
232
astrbot/core/provider/sources/anthropic_source.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
from typing import List
|
||||||
|
from mimetypes import guess_type
|
||||||
|
|
||||||
|
from anthropic import AsyncAnthropic
|
||||||
|
from anthropic.types import Message
|
||||||
|
|
||||||
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.api.provider import Provider, Personality
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
|
from ..register import register_provider_adapter
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||||
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
|
|
||||||
|
|
||||||
|
@register_provider_adapter(
|
||||||
|
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
|
||||||
|
)
|
||||||
|
class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_config: dict,
|
||||||
|
provider_settings: dict,
|
||||||
|
db_helper: BaseDatabase,
|
||||||
|
persistant_history=True,
|
||||||
|
default_persona: Personality = None,
|
||||||
|
) -> None:
|
||||||
|
# Skip OpenAI's __init__ and call Provider's __init__ directly
|
||||||
|
Provider.__init__(
|
||||||
|
self,
|
||||||
|
provider_config,
|
||||||
|
provider_settings,
|
||||||
|
persistant_history,
|
||||||
|
db_helper,
|
||||||
|
default_persona,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.chosen_api_key = None
|
||||||
|
self.api_keys: List = provider_config.get("key", [])
|
||||||
|
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||||
|
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||||
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
|
if isinstance(self.timeout, str):
|
||||||
|
self.timeout = int(self.timeout)
|
||||||
|
|
||||||
|
self.client = AsyncAnthropic(
|
||||||
|
api_key=self.chosen_api_key, timeout=self.timeout, base_url=self.base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_model(provider_config["model_config"]["model"])
|
||||||
|
|
||||||
|
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||||
|
if tools:
|
||||||
|
tool_list = tools.get_func_desc_anthropic_style()
|
||||||
|
if tool_list:
|
||||||
|
payloads["tools"] = tool_list
|
||||||
|
|
||||||
|
completion = await self.client.messages.create(**payloads, stream=False)
|
||||||
|
|
||||||
|
assert isinstance(completion, Message)
|
||||||
|
logger.debug(f"completion: {completion}")
|
||||||
|
|
||||||
|
if len(completion.content) == 0:
|
||||||
|
raise Exception("API 返回的 completion 为空。")
|
||||||
|
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
|
||||||
|
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
|
||||||
|
content = completion.content[-1]
|
||||||
|
|
||||||
|
llm_response = LLMResponse("assistant")
|
||||||
|
|
||||||
|
if content.type == "text":
|
||||||
|
# text completion
|
||||||
|
completion_text = str(content.text).strip()
|
||||||
|
# llm_response.completion_text = completion_text
|
||||||
|
llm_response.result_chain = MessageChain().message(completion_text)
|
||||||
|
|
||||||
|
# Anthropic每次只返回一个函数调用
|
||||||
|
if completion.stop_reason == "tool_use":
|
||||||
|
# tools call (function calling)
|
||||||
|
args_ls = []
|
||||||
|
func_name_ls = []
|
||||||
|
tool_use_ids = []
|
||||||
|
func_name_ls.append(content.name)
|
||||||
|
args_ls.append(content.input)
|
||||||
|
tool_use_ids.append(content.id)
|
||||||
|
llm_response.role = "tool"
|
||||||
|
llm_response.tools_call_args = args_ls
|
||||||
|
llm_response.tools_call_name = func_name_ls
|
||||||
|
llm_response.tools_call_ids = tool_use_ids
|
||||||
|
|
||||||
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
|
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
|
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
|
|
||||||
|
llm_response.raw_completion = completion
|
||||||
|
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = [],
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts=[],
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result: ToolCallsResult = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> LLMResponse:
|
||||||
|
if not prompt:
|
||||||
|
prompt = "<image>"
|
||||||
|
|
||||||
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
|
context_query = [*contexts, new_record]
|
||||||
|
|
||||||
|
for part in context_query:
|
||||||
|
if "_no_save" in part:
|
||||||
|
del part["_no_save"]
|
||||||
|
|
||||||
|
if tool_calls_result:
|
||||||
|
# 暂时这样写。
|
||||||
|
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
|
||||||
|
|
||||||
|
model_config = self.provider_config.get("model_config", {})
|
||||||
|
|
||||||
|
payloads = {"messages": context_query, **model_config}
|
||||||
|
# Anthropic has a different way of handling system prompts
|
||||||
|
if system_prompt:
|
||||||
|
payloads["system"] = system_prompt
|
||||||
|
|
||||||
|
llm_response = None
|
||||||
|
try:
|
||||||
|
llm_response = await self._query(payloads, func_tool)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if "maximum context length" in str(e):
|
||||||
|
retry_cnt = 20
|
||||||
|
while retry_cnt > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self.pop_record(context_query)
|
||||||
|
response = await self.client.messages.create(
|
||||||
|
messages=context_query, **model_config
|
||||||
|
)
|
||||||
|
llm_response = LLMResponse("assistant")
|
||||||
|
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
||||||
|
llm_response.raw_completion = response
|
||||||
|
return llm_response
|
||||||
|
except Exception as e:
|
||||||
|
if "maximum context length" in str(e):
|
||||||
|
retry_cnt -= 1
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
||||||
|
else:
|
||||||
|
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=...,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=...,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
# 调用 text_chat 模拟流式
|
||||||
|
llm_response = await self.text_chat(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
)
|
||||||
|
llm_response.is_chunk = True
|
||||||
|
yield llm_response
|
||||||
|
llm_response.is_chunk = False
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
|
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||||
|
"""组装上下文,支持文本和图片"""
|
||||||
|
if not image_urls:
|
||||||
|
return {"role": "user", "content": text}
|
||||||
|
|
||||||
|
content = []
|
||||||
|
content.append({"type": "text", "text": text})
|
||||||
|
|
||||||
|
for image_url in image_urls:
|
||||||
|
if image_url.startswith("http"):
|
||||||
|
image_path = await download_image_by_url(image_url)
|
||||||
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
|
elif image_url.startswith("file:///"):
|
||||||
|
image_path = image_url.replace("file:///", "")
|
||||||
|
image_data = await self.encode_image_bs64(image_path)
|
||||||
|
else:
|
||||||
|
image_data = await self.encode_image_bs64(image_url)
|
||||||
|
|
||||||
|
if not image_data:
|
||||||
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get mime type for the image
|
||||||
|
mime_type, _ = guess_type(image_url)
|
||||||
|
if not mime_type:
|
||||||
|
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||||
|
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": mime_type,
|
||||||
|
"data": image_data.split("base64,")[1]
|
||||||
|
if "base64," in image_data
|
||||||
|
else image_data,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"role": "user", "content": content}
|
||||||
203
astrbot/core/provider/sources/dashscope_source.py
Normal file
203
astrbot/core/provider/sources/dashscope_source.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
import re
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from typing import List
|
||||||
|
from .. import Provider, Personality
|
||||||
|
from ..entities import LLMResponse
|
||||||
|
from ..func_tool_manager import FuncCall
|
||||||
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from ..register import register_provider_adapter
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
|
from astrbot.core import logger, sp
|
||||||
|
from dashscope import Application
|
||||||
|
|
||||||
|
|
||||||
|
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||||
|
class ProviderDashscope(ProviderOpenAIOfficial):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_config: dict,
|
||||||
|
provider_settings: dict,
|
||||||
|
db_helper: BaseDatabase,
|
||||||
|
persistant_history=False,
|
||||||
|
default_persona: Personality = None,
|
||||||
|
) -> None:
|
||||||
|
Provider.__init__(
|
||||||
|
self,
|
||||||
|
provider_config,
|
||||||
|
provider_settings,
|
||||||
|
persistant_history,
|
||||||
|
db_helper,
|
||||||
|
default_persona,
|
||||||
|
)
|
||||||
|
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||||
|
if not self.api_key:
|
||||||
|
raise Exception("阿里云百炼 API Key 不能为空。")
|
||||||
|
self.app_id = provider_config.get("dashscope_app_id", "")
|
||||||
|
if not self.app_id:
|
||||||
|
raise Exception("阿里云百炼 APP ID 不能为空。")
|
||||||
|
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
||||||
|
if not self.dashscope_app_type:
|
||||||
|
raise Exception("阿里云百炼 APP 类型不能为空。")
|
||||||
|
self.model_name = "dashscope"
|
||||||
|
self.variables: dict = provider_config.get("variables", {})
|
||||||
|
self.rag_options: dict = provider_config.get("rag_options", {})
|
||||||
|
self.output_reference = self.rag_options.get("output_reference", False)
|
||||||
|
self.rag_options = self.rag_options.copy()
|
||||||
|
self.rag_options.pop("output_reference", None)
|
||||||
|
|
||||||
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
|
if isinstance(self.timeout, str):
|
||||||
|
self.timeout = int(self.timeout)
|
||||||
|
|
||||||
|
def has_rag_options(self):
|
||||||
|
"""判断是否有 RAG 选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否有 RAG 选项
|
||||||
|
"""
|
||||||
|
if self.rag_options and (
|
||||||
|
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||||
|
or len(self.rag_options.get("file_ids", [])) > 0
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def text_chat(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = [],
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts: List = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> LLMResponse:
|
||||||
|
# 获得会话变量
|
||||||
|
payload_vars = self.variables.copy()
|
||||||
|
# 动态变量
|
||||||
|
session_vars = sp.get("session_variables", {})
|
||||||
|
session_var = session_vars.get(session_id, {})
|
||||||
|
payload_vars.update(session_var)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||||
|
and not self.has_rag_options()
|
||||||
|
):
|
||||||
|
# 支持多轮对话的
|
||||||
|
new_record = {"role": "user", "content": prompt}
|
||||||
|
if image_urls:
|
||||||
|
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
||||||
|
contexts_no_img = await self._remove_image_from_context(contexts)
|
||||||
|
context_query = [*contexts_no_img, new_record]
|
||||||
|
if system_prompt:
|
||||||
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||||
|
for part in context_query:
|
||||||
|
if "_no_save" in part:
|
||||||
|
del part["_no_save"]
|
||||||
|
# 调用阿里云百炼 API
|
||||||
|
payload = {
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"messages": context_query,
|
||||||
|
"biz_params": payload_vars or None,
|
||||||
|
}
|
||||||
|
partial = functools.partial(
|
||||||
|
Application.call,
|
||||||
|
**payload,
|
||||||
|
)
|
||||||
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||||
|
else:
|
||||||
|
# 不支持多轮对话的
|
||||||
|
# 调用阿里云百炼 API
|
||||||
|
payload = {
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"prompt": prompt,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"biz_params": payload_vars or None,
|
||||||
|
}
|
||||||
|
if self.rag_options:
|
||||||
|
payload["rag_options"] = self.rag_options
|
||||||
|
partial = functools.partial(
|
||||||
|
Application.call,
|
||||||
|
**payload,
|
||||||
|
)
|
||||||
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||||
|
|
||||||
|
logger.debug(f"dashscope resp: {response}")
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(
|
||||||
|
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
|
||||||
|
)
|
||||||
|
return LLMResponse(
|
||||||
|
role="err",
|
||||||
|
result_chain=MessageChain().message(
|
||||||
|
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
output_text = response.output.get("text", "")
|
||||||
|
# RAG 引用脚标格式化
|
||||||
|
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||||
|
if self.output_reference and response.output.get("doc_references", None):
|
||||||
|
ref_str = ""
|
||||||
|
for ref in response.output.get("doc_references", []):
|
||||||
|
ref_title = (
|
||||||
|
ref.get("title", "")
|
||||||
|
if ref.get("title")
|
||||||
|
else ref.get("doc_name", "")
|
||||||
|
)
|
||||||
|
ref_str += f"{ref['index_id']}. {ref_title}\n"
|
||||||
|
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||||
|
|
||||||
|
llm_response = LLMResponse("assistant")
|
||||||
|
llm_response.result_chain = MessageChain().message(output_text)
|
||||||
|
|
||||||
|
return llm_response
|
||||||
|
|
||||||
|
async def text_chat_stream(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
session_id=None,
|
||||||
|
image_urls=...,
|
||||||
|
func_tool=None,
|
||||||
|
contexts=...,
|
||||||
|
system_prompt=None,
|
||||||
|
tool_calls_result=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# raise NotImplementedError("This method is not implemented yet.")
|
||||||
|
# 调用 text_chat 模拟流式
|
||||||
|
llm_response = await self.text_chat(
|
||||||
|
prompt=prompt,
|
||||||
|
session_id=session_id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
func_tool=func_tool,
|
||||||
|
contexts=contexts,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tool_calls_result=tool_calls_result,
|
||||||
|
)
|
||||||
|
llm_response.is_chunk = True
|
||||||
|
yield llm_response
|
||||||
|
llm_response.is_chunk = False
|
||||||
|
yield llm_response
|
||||||
|
|
||||||
|
async def forget(self, session_id):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_current_key(self):
|
||||||
|
return self.api_key
|
||||||
|
|
||||||
|
async def set_key(self, key):
|
||||||
|
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
|
||||||
|
|
||||||
|
async def get_models(self):
|
||||||
|
return [self.get_model()]
|
||||||
|
|
||||||
|
async def get_human_readable_context(self, session_id, page, page_size):
|
||||||
|
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
pass
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user