Compare commits
904 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9564166297 | ||
|
|
f5cf3c3c8e | ||
|
|
18f919fb6b | ||
|
|
0924835253 | ||
|
|
20d2e5c578 | ||
|
|
907801605c | ||
|
|
93bc684e8c | ||
|
|
a76c98d57e | ||
|
|
d937a800d0 | ||
|
|
d16f3a227f | ||
|
|
80c9a3eeda | ||
|
|
e68173b451 | ||
|
|
40c27d87f5 | ||
|
|
3c13b5049d | ||
|
|
8288d5e51f | ||
|
|
4ffbb18ab4 | ||
|
|
b27271b7a3 | ||
|
|
ebb6665f64 | ||
|
|
e4e5731ffd | ||
|
|
2ab5810f13 | ||
|
|
af934c5d09 | ||
|
|
1e0cf7c112 | ||
|
|
46859c93c9 | ||
|
|
1641549016 | ||
|
|
716a5dbb8a | ||
|
|
af98cb11c5 | ||
|
|
9a4c2cf341 | ||
|
|
2bc3bcd102 | ||
|
|
d6c663f79d | ||
|
|
2b4ee13b5e | ||
|
|
6959f86632 | ||
|
|
537d373e10 | ||
|
|
cceadf222c | ||
|
|
cf5a4af623 | ||
|
|
39aea11c22 | ||
|
|
c2f1227700 | ||
|
|
900f14d37c | ||
|
|
598249b1d6 | ||
|
|
7ed15bdf04 | ||
|
|
2fc0ec0f72 | ||
|
|
5e9c2a669b | ||
|
|
b310521884 | ||
|
|
288945bf7e | ||
|
|
4fc07cff36 | ||
|
|
b884fe0e86 | ||
|
|
855858c236 | ||
|
|
c11a2a5419 | ||
|
|
773a6572af | ||
|
|
88ad373c9b | ||
|
|
51666464b9 | ||
|
|
5af9cf2f52 | ||
|
|
12c4ae4b10 | ||
|
|
4e1bef414a | ||
|
|
e896c18644 | ||
|
|
c852685e74 | ||
|
|
1e99797df8 | ||
|
|
52a4c986a8 | ||
|
|
c501728204 | ||
|
|
6b067fa6a7 | ||
|
|
a1cd5c53a9 | ||
|
|
a46d487e03 | ||
|
|
3deb6d3ab3 | ||
|
|
af34cdd5d2 | ||
|
|
6e1393235a | ||
|
|
343e0b54b9 | ||
|
|
ecb70cb6f7 | ||
|
|
ca50618af6 | ||
|
|
29c07ba83e | ||
|
|
45fbb83a9f | ||
|
|
ae7ba2df25 | ||
|
|
c3ef57cc32 | ||
|
|
7bb4ca5a14 | ||
|
|
063783d81d | ||
|
|
42116c9b65 | ||
|
|
a36e11973d | ||
|
|
5125568ea2 | ||
|
|
0fa164e50d | ||
|
|
cf814e81ee | ||
|
|
43a45f18ce | ||
|
|
ad51381063 | ||
|
|
0b0e4ce904 | ||
|
|
6a3e04d688 | ||
|
|
4107a17370 | ||
|
|
06b4d8f169 | ||
|
|
1c0c820746 | ||
|
|
d061403a28 | ||
|
|
5c092321a6 | ||
|
|
bdd3f61c1f | ||
|
|
8023557d6e | ||
|
|
074b0ced7a | ||
|
|
3864b1ac9b | ||
|
|
6e9b43457d | ||
|
|
ca1aec8920 | ||
|
|
acac580862 | ||
|
|
673e1b2980 | ||
|
|
f62157be72 | ||
|
|
f894ecf3b6 | ||
|
|
66dd4e28ad | ||
|
|
939dc1b0fb | ||
|
|
56bf5d38a1 | ||
|
|
d09b70b295 | ||
|
|
205180387a | ||
|
|
39c8cfeda5 | ||
|
|
f38a329be5 | ||
|
|
a0cd069539 | ||
|
|
bf306a2f01 | ||
|
|
c31f93a8d1 | ||
|
|
4730ab6309 | ||
|
|
1ae78ca98c | ||
|
|
d2379da478 | ||
|
|
0f64981b20 | ||
|
|
0002e49bb5 | ||
|
|
db13a60274 | ||
|
|
db0f11a359 | ||
|
|
ac7f43520b | ||
|
|
f67b9f5f6e | ||
|
|
c75156c4ce | ||
|
|
10270b5595 | ||
|
|
f7458572ed | ||
|
|
d57b7222b2 | ||
|
|
62e70a673a | ||
|
|
5e9eba6478 | ||
|
|
cb02dfe1a4 | ||
|
|
b50739e1af | ||
|
|
8da1b0212d | ||
|
|
ca1f2acb33 | ||
|
|
c15f966669 | ||
|
|
7705b8781a | ||
|
|
b2502746f0 | ||
|
|
ab68094386 | ||
|
|
bbec701223 | ||
|
|
b29d14e600 | ||
|
|
86e51c5cd1 | ||
|
|
cb8267be3f | ||
|
|
eaed43915c | ||
|
|
bd91fd2c38 | ||
|
|
1203b214cd | ||
|
|
c3fec15f11 | ||
|
|
0545653494 | ||
|
|
db2989bdb4 | ||
|
|
587bd00a19 | ||
|
|
960ff438e8 | ||
|
|
98e7ea85d3 | ||
|
|
2549e44710 | ||
|
|
4d32b563ca | ||
|
|
3a4b732977 | ||
|
|
500909a28e | ||
|
|
07753eb25b | ||
|
|
c6eaf3d010 | ||
|
|
6723fe8271 | ||
|
|
3348b70435 | ||
|
|
35a8527c16 | ||
|
|
7afc475290 | ||
|
|
789bceaa3a | ||
|
|
abbc043969 | ||
|
|
654e5762f1 | ||
|
|
507c3e3629 | ||
|
|
991dfeb2f2 | ||
|
|
26482fc2d3 | ||
|
|
e0ce6d9688 | ||
|
|
946595216a | ||
|
|
864b6bc56d | ||
|
|
6ea5b7581f | ||
|
|
f70b8f0c10 | ||
|
|
1593bcb537 | ||
|
|
bf7fc02c8d | ||
|
|
143702b92b | ||
|
|
c5ccc1a084 | ||
|
|
2ecb52a9b2 | ||
|
|
6439917cbe | ||
|
|
d21c18f657 | ||
|
|
25ef0039e4 | ||
|
|
e6981290bc | ||
|
|
75c3d8abbd | ||
|
|
d88683f498 | ||
|
|
40b9aa3a4c | ||
|
|
b6d1515d58 | ||
|
|
e01d4264e3 | ||
|
|
2117b65487 | ||
|
|
a7823b352f | ||
|
|
c543b62a08 | ||
|
|
3923b87f08 | ||
|
|
b7ecdadb83 | ||
|
|
5ff121e1ed | ||
|
|
f486e5448f | ||
|
|
c5aae98558 | ||
|
|
6d8a3b9897 | ||
|
|
6d98780e19 | ||
|
|
3ad2c46f3f | ||
|
|
a730cee7fd | ||
|
|
77c823c100 | ||
|
|
124f21c67a | ||
|
|
e46cf20dd3 | ||
|
|
4bef5e8313 | ||
|
|
22e93b0af4 | ||
|
|
5aeca9662b | ||
|
|
b996cf1f05 | ||
|
|
878a106877 | ||
|
|
45d36f86fd | ||
|
|
b108ae403a | ||
|
|
887ed66768 | ||
|
|
dac840a887 | ||
|
|
238de4ba8c | ||
|
|
9a7bdade43 | ||
|
|
aa84556204 | ||
|
|
6b68069fcd | ||
|
|
42c7034fb2 | ||
|
|
060c7e0145 | ||
|
|
b5b085dfb1 | ||
|
|
fc06ce9d7f | ||
|
|
d8d81b05a7 | ||
|
|
a60f42b1f2 | ||
|
|
6e18be88d0 | ||
|
|
b45e439c48 | ||
|
|
b87061c18c | ||
|
|
f78aca7752 | ||
|
|
3ccca2aa10 | ||
|
|
6d7c40eb76 | ||
|
|
da4cd7fb65 | ||
|
|
c97cda6b84 | ||
|
|
7a7fd4167a | ||
|
|
dffc1a43d5 | ||
|
|
36897fea1e | ||
|
|
c7b34735f0 | ||
|
|
5b07176c88 | ||
|
|
474b40d660 | ||
|
|
a62901b948 | ||
|
|
25d8746327 | ||
|
|
aff1698223 | ||
|
|
7f8941745f | ||
|
|
b858401098 | ||
|
|
d5a158b80f | ||
|
|
f315f284aa | ||
|
|
c367f5009d | ||
|
|
6db1e63bda | ||
|
|
e22ab2ede6 | ||
|
|
b7d7e0b682 | ||
|
|
96bba15f2f | ||
|
|
fcf965a595 | ||
|
|
e1a20d3c22 | ||
|
|
2abd7d8c5d | ||
|
|
5b8f73cdd7 | ||
|
|
7fd765421f | ||
|
|
d9d94af022 | ||
|
|
790b924e57 | ||
|
|
4a62f877df | ||
|
|
ac47c57bb7 | ||
|
|
3ace4199a1 | ||
|
|
e6bd7524c1 | ||
|
|
699c86e8c1 | ||
|
|
f40fa0ecea | ||
|
|
626f94686b | ||
|
|
752d13b1b1 | ||
|
|
54c0dc1b2b | ||
|
|
c5bc709898 | ||
|
|
ccdbb01513 | ||
|
|
5206d750ac | ||
|
|
a800e3df67 | ||
|
|
ccb1f87a20 | ||
|
|
c111da4681 | ||
|
|
9cc4e97a53 | ||
|
|
dca1c0b0f3 | ||
|
|
f06be6ed21 | ||
|
|
3c8ec2f42e | ||
|
|
7e193f7f52 | ||
|
|
7069b02929 | ||
|
|
66995db927 | ||
|
|
c36054ca1b | ||
|
|
3e07fbf3dc | ||
|
|
bf3fbe3e96 | ||
|
|
0a93d22bc8 | ||
|
|
f5b3d94d16 | ||
|
|
4d1a6994aa | ||
|
|
05c686782c | ||
|
|
85609ea742 | ||
|
|
20dabc0615 | ||
|
|
356dd9bc2b | ||
|
|
cd5d7534c4 | ||
|
|
b4f12fc933 | ||
|
|
cbea387ce0 | ||
|
|
345b155374 | ||
|
|
29d216950e | ||
|
|
321b04772c | ||
|
|
5b924aee98 | ||
|
|
46d44e3405 | ||
|
|
4d5332fe25 | ||
|
|
18bd4c54f4 | ||
|
|
31c7768ca0 | ||
|
|
6ec643e9d1 | ||
|
|
2b39f6f61c | ||
|
|
bf3ca13961 | ||
|
|
82026370ec | ||
|
|
6d49bf5346 | ||
|
|
67431d87fb | ||
|
|
fdf55221e6 | ||
|
|
07f277dd3b | ||
|
|
cf8f0603ca | ||
|
|
5592408ab8 | ||
|
|
a01617b45c | ||
|
|
7abb4087b3 | ||
|
|
dff15cf27a | ||
|
|
aa858137e5 | ||
|
|
45cb143202 | ||
|
|
7a9c6ab8c4 | ||
|
|
e2c26c292d | ||
|
|
be7c3fd00e | ||
|
|
7e5461a2cf | ||
|
|
6ee9010645 | ||
|
|
a23d5be056 | ||
|
|
97a6a1fdc2 | ||
|
|
c8f567347b | ||
|
|
74c1e7f69e | ||
|
|
15a5fc0cae | ||
|
|
f07c54d47c | ||
|
|
70446be108 | ||
|
|
d6d21fca56 | ||
|
|
8d7273924f | ||
|
|
ea64afbaa7 | ||
|
|
45da9837ec | ||
|
|
8c19b7d163 | ||
|
|
ab227a08d0 | ||
|
|
40d6e77964 | ||
|
|
9326e3f1b0 | ||
|
|
0e1eb3daf6 | ||
|
|
05daac12ed | ||
|
|
c5b24b4764 | ||
|
|
cc16548e5f | ||
|
|
291d65bb3e | ||
|
|
bd3ad03da6 | ||
|
|
5fa6788357 | ||
|
|
c5c5a98ac4 | ||
|
|
a1151143cf | ||
|
|
f5024984f7 | ||
|
|
f4880fd90d | ||
|
|
0ae61d5865 | ||
|
|
d3bd775a79 | ||
|
|
da546cfe7f | ||
|
|
a211933e83 | ||
|
|
1d40b5a821 | ||
|
|
33836daeb7 | ||
|
|
d921b0f6bd | ||
|
|
0607b95df6 | ||
|
|
0de6d0e046 | ||
|
|
98427345cf | ||
|
|
9fedaa9f77 | ||
|
|
bf4c2ecd33 | ||
|
|
f8c18cc1e0 | ||
|
|
458b900412 | ||
|
|
192c776e0b | ||
|
|
5cdec18863 | ||
|
|
15f856f951 | ||
|
|
01d52cef74 | ||
|
|
95563c8659 | ||
|
|
31d8c40eca | ||
|
|
56001ed272 | ||
|
|
d916fda04c | ||
|
|
cfae655068 | ||
|
|
5596565ec4 | ||
|
|
afa1aa5d93 | ||
|
|
e98c3d8393 | ||
|
|
6687b816f0 | ||
|
|
ea8035e854 | ||
|
|
54b0171d49 | ||
|
|
676d4277b9 | ||
|
|
a4b1da3ca2 | ||
|
|
9e9c16e770 | ||
|
|
dc87006fed | ||
|
|
b9b260f26a | ||
|
|
33fd6a5016 | ||
|
|
97cbccc2ba | ||
|
|
1ee4685d5d | ||
|
|
aba18232b1 | ||
|
|
0a02441b75 | ||
|
|
1be5b4c7ff | ||
|
|
a0ce0cf18a | ||
|
|
7c54e5d093 | ||
|
|
b825e51dab | ||
|
|
589855c393 | ||
|
|
4c546f2f53 | ||
|
|
3753fce912 | ||
|
|
4c02857ec5 | ||
|
|
33f87ff7d7 | ||
|
|
784dcf2a9a | ||
|
|
43ee943acb | ||
|
|
a769fd7d13 | ||
|
|
2c4fd00b16 | ||
|
|
264771fe98 | ||
|
|
ecd92dafef | ||
|
|
c8b6e4bea3 | ||
|
|
3756cb766e | ||
|
|
068d9ca60b | ||
|
|
93f632d8b8 | ||
|
|
bb44ce7e74 | ||
|
|
6986c8d8f7 | ||
|
|
fe95506db4 | ||
|
|
310ed76b18 | ||
|
|
98830d147f | ||
|
|
19c9177d7b | ||
|
|
f41c5f97f6 | ||
|
|
648c125697 | ||
|
|
0dc2b89897 | ||
|
|
83745f83a5 | ||
|
|
2f91fe4535 | ||
|
|
739f09059e | ||
|
|
c86f9f0f5f | ||
|
|
9470ca6bc5 | ||
|
|
2a92c4d5de | ||
|
|
bb6e892657 | ||
|
|
c9079b9299 | ||
|
|
b6963c1bf9 | ||
|
|
9c29df47bb | ||
|
|
fc146d3d00 | ||
|
|
1bf5a21678 | ||
|
|
011542dc2b | ||
|
|
489784104e | ||
|
|
3860634fd2 | ||
|
|
709c324e18 | ||
|
|
b75d24d92c | ||
|
|
ed80e9424c | ||
|
|
2fe1f2060a | ||
|
|
c6df820164 | ||
|
|
d6239822db | ||
|
|
bced9ffff9 | ||
|
|
d7d1c1544a | ||
|
|
7c1e8ce48c | ||
|
|
e3b0ca8ef6 | ||
|
|
9e266eb6d5 | ||
|
|
7231403e16 | ||
|
|
344a486fd7 | ||
|
|
4fd831875d | ||
|
|
0988d067ea | ||
|
|
44dbe475af | ||
|
|
bd24cf3ea4 | ||
|
|
b493a808fe | ||
|
|
54035d108d | ||
|
|
c5e8bc7e20 | ||
|
|
3bbb4779a3 | ||
|
|
1b3963ebea | ||
|
|
3b6dd7e15a | ||
|
|
757d2a3947 | ||
|
|
61b71143f2 | ||
|
|
1b343a36c9 | ||
|
|
8e94937060 | ||
|
|
e8ffebc006 | ||
|
|
2ca95eaa9f | ||
|
|
0dc5b4cdfc | ||
|
|
cc6cd96d8e | ||
|
|
4244d37625 | ||
|
|
0b766095d4 | ||
|
|
a4f212a18f | ||
|
|
caafb73190 | ||
|
|
09482799c9 | ||
|
|
37f93d1760 | ||
|
|
725f2e5204 | ||
|
|
967198fae0 | ||
|
|
43d57f6dcb | ||
|
|
6afa4db577 | ||
|
|
3b8c3fb29a | ||
|
|
921c3b0627 | ||
|
|
c0fadb45ab | ||
|
|
a1481fb179 | ||
|
|
987cd972d3 | ||
|
|
bdf25976a3 | ||
|
|
87c3aff4ce | ||
|
|
99350a957a | ||
|
|
319068dc7e | ||
|
|
cd18806c39 | ||
|
|
95b08b2023 | ||
|
|
0e70f76c86 | ||
|
|
4d414a2994 | ||
|
|
3d22772d4e | ||
|
|
0b381e2570 | ||
|
|
f2cc4311c5 | ||
|
|
e349671fdf | ||
|
|
01c02d5efa | ||
|
|
b62b1f3870 | ||
|
|
8844830859 | ||
|
|
0c51ee4b64 | ||
|
|
11920d5e31 | ||
|
|
848ea1eb63 | ||
|
|
a216519486 | ||
|
|
b04606c38e | ||
|
|
38072beea7 | ||
|
|
b843f1fa03 | ||
|
|
560d40e571 | ||
|
|
5f0b8161b7 | ||
|
|
062d482917 | ||
|
|
39693a27e3 | ||
|
|
7cd1eeac30 | ||
|
|
bafa473c8e | ||
|
|
750cf46b2e | ||
|
|
68885a4bbc | ||
|
|
bcc99a8904 | ||
|
|
59fbd98db3 | ||
|
|
b70ed425f1 | ||
|
|
45ef5811c8 | ||
|
|
3b137ac762 | ||
|
|
1ddb0caf73 | ||
|
|
ae4c6fe2dd | ||
|
|
b03fe438d0 | ||
|
|
db257af58e | ||
|
|
735368c71b | ||
|
|
9e04e3679b | ||
|
|
43b8414727 | ||
|
|
5a00187147 | ||
|
|
cb525c7c84 | ||
|
|
d88420dd03 | ||
|
|
b9a983f8e0 | ||
|
|
42431ea7db | ||
|
|
f9459e4abb | ||
|
|
72f917d611 | ||
|
|
9fd1d19e93 | ||
|
|
062af1ac08 | ||
|
|
41bd76e091 | ||
|
|
cfd3f4b199 | ||
|
|
79d38f9597 | ||
|
|
b3866559e1 | ||
|
|
4d186baa35 | ||
|
|
8ed3d5f3db | ||
|
|
f0c8f39b6d | ||
|
|
431db8fc9b | ||
|
|
ba252c5356 | ||
|
|
a2812c39c0 | ||
|
|
0490758820 | ||
|
|
7f56824b42 | ||
|
|
627da3a2bc | ||
|
|
9b36a5c8a6 | ||
|
|
c1cf2be533 | ||
|
|
e6b69042de | ||
|
|
109650faf3 | ||
|
|
e54eaab842 | ||
|
|
43b6297b5d | ||
|
|
c20f4f5adf | ||
|
|
dc1f222cd2 | ||
|
|
c2b687212c | ||
|
|
849913276d | ||
|
|
23579c1e4a | ||
|
|
e031161fd4 | ||
|
|
4800ee6c0a | ||
|
|
d3a7fef9b0 | ||
|
|
40822fe77a | ||
|
|
837b670213 | ||
|
|
57ce69f3fb | ||
|
|
be022c4894 | ||
|
|
8a366964bb | ||
|
|
ee86b68470 | ||
|
|
60352307aa | ||
|
|
3ebd2f746f | ||
|
|
1c1a65b637 | ||
|
|
010e60d029 | ||
|
|
7a25568861 | ||
|
|
5f4f913661 | ||
|
|
ccd0e34a53 | ||
|
|
72f1ffccd3 | ||
|
|
ea7a52945f | ||
|
|
89d4d1351a | ||
|
|
b757c91d93 | ||
|
|
27203d7a4d | ||
|
|
9ad4e18ac5 | ||
|
|
fcdc8f3ce7 | ||
|
|
78b994b84a | ||
|
|
58bfc677e2 | ||
|
|
7d17285a0c | ||
|
|
e9eb00a0d4 | ||
|
|
48d07af574 | ||
|
|
2fc62efd88 | ||
|
|
be516d75bd | ||
|
|
951d5fde85 | ||
|
|
1389abc052 | ||
|
|
19ad67a77f | ||
|
|
641f308344 | ||
|
|
9f097fa4d5 | ||
|
|
5ad362c52b | ||
|
|
614f238a61 | ||
|
|
dec91950bc | ||
|
|
6cef9c23f0 | ||
|
|
3f568bf136 | ||
|
|
5484b421ce | ||
|
|
02f21e07d3 | ||
|
|
fff1f23a83 | ||
|
|
a056ec0d38 | ||
|
|
2eb9e5dde3 | ||
|
|
627d2a4701 | ||
|
|
76895fe86d | ||
|
|
64c3c85780 | ||
|
|
7288348857 | ||
|
|
62e73299b1 | ||
|
|
fe76c41ed8 | ||
|
|
1a92edf8be | ||
|
|
b63b606a4e | ||
|
|
8e2ef3d22b | ||
|
|
c6c4a32283 | ||
|
|
b70b3b158e | ||
|
|
3d59ab8108 | ||
|
|
b6c3089510 | ||
|
|
bd92aac280 | ||
|
|
5299e802e9 | ||
|
|
8e5a57d7dd | ||
|
|
beaa324fb6 | ||
|
|
79e64fe206 | ||
|
|
93f525e3fe | ||
|
|
aacb803c64 | ||
|
|
8a0665b222 | ||
|
|
20e41a7f73 | ||
|
|
93a1699a35 | ||
|
|
c33c07e4af | ||
|
|
c7484d0cc9 | ||
|
|
fb85a7bb35 | ||
|
|
42ff9a4d34 | ||
|
|
005e9eae7c | ||
|
|
3e325debcc | ||
|
|
a221de9a2b | ||
|
|
32b0cc1865 | ||
|
|
bbf85f8a12 | ||
|
|
67a0172b28 | ||
|
|
fb19d4d45b | ||
|
|
a156b1af14 | ||
|
|
a604b4943c | ||
|
|
3f0b6435d9 | ||
|
|
e0f029e2cb | ||
|
|
89d3fd5fab | ||
|
|
a38b00be6b | ||
|
|
0e8d52b591 | ||
|
|
298c77740d | ||
|
|
c681aae8ee | ||
|
|
faef98b089 | ||
|
|
84a3e0a30b | ||
|
|
69bd553ce0 | ||
|
|
fd0c0f8975 | ||
|
|
860ceb06b4 | ||
|
|
ecf501bf72 | ||
|
|
81a2ed1e25 | ||
|
|
76ab28338a | ||
|
|
9a56c9630f | ||
|
|
53b9497c18 | ||
|
|
750b16b6ee | ||
|
|
0ee3e0779a | ||
|
|
333c2d9299 | ||
|
|
ad37ff5048 | ||
|
|
33f86f3bde | ||
|
|
8acb969a49 | ||
|
|
b74b5933b8 | ||
|
|
681c556b7e | ||
|
|
1746684e52 | ||
|
|
0b93d06555 | ||
|
|
8a8b8c7c27 | ||
|
|
6b6577006d | ||
|
|
23ee5e81c9 | ||
|
|
483f55e4b1 | ||
|
|
1bb1bc2553 | ||
|
|
a4e4e36f94 | ||
|
|
6849415812 | ||
|
|
86f6cb038e | ||
|
|
7480a1d6ce | ||
|
|
3cd10117dd | ||
|
|
0caf19d390 | ||
|
|
5c14ebb049 | ||
|
|
9717a736b1 | ||
|
|
9c9ab50d1a | ||
|
|
d4bcb8174e | ||
|
|
9e7fe773bd | ||
|
|
aca18fab0f | ||
|
|
691de01b79 | ||
|
|
3383f15142 | ||
|
|
84c1593889 | ||
|
|
3c80fa1e33 | ||
|
|
06b16a1deb | ||
|
|
4c4246fb09 | ||
|
|
364be1e9f6 | ||
|
|
f959ed71aa | ||
|
|
5c4326c302 | ||
|
|
125fc3a622 | ||
|
|
6b9e785db3 | ||
|
|
25d34e9a43 | ||
|
|
457d4aa1dc | ||
|
|
ff0c0992ff | ||
|
|
d379e012c4 | ||
|
|
151fff26fd | ||
|
|
3d0d561215 | ||
|
|
22d586ed7b | ||
|
|
6dc19b29e8 | ||
|
|
50975a87d4 | ||
|
|
ce721d9f0f | ||
|
|
20510a33f7 | ||
|
|
3abd9c8763 | ||
|
|
e9eff7420b | ||
|
|
64c250c9d8 | ||
|
|
8047f82bfd | ||
|
|
af6467fb3d | ||
|
|
3ff1664aec | ||
|
|
34ea2b44b8 | ||
|
|
6c8d851109 | ||
|
|
d678299a74 | ||
|
|
7aed0db2b6 | ||
|
|
0355524345 | ||
|
|
0a43e4672e | ||
|
|
71e0ccdfec | ||
|
|
1df33ac3c8 | ||
|
|
7334090ac1 | ||
|
|
6b0f044198 | ||
|
|
ddf54c9cf8 | ||
|
|
7c64e184e2 | ||
|
|
a904db033c | ||
|
|
b234856b02 | ||
|
|
89d51d2afc | ||
|
|
37cb9678e9 | ||
|
|
0500ff333a | ||
|
|
08528510ef | ||
|
|
ddbd03dc1e | ||
|
|
ade87f378a | ||
|
|
4db14b905f | ||
|
|
b669b31451 | ||
|
|
1cb2b62f81 | ||
|
|
e5828713cf | ||
|
|
d10cb84068 | ||
|
|
4222f8516f | ||
|
|
7f998c7611 | ||
|
|
db46000337 | ||
|
|
1aac8d8041 | ||
|
|
c59c8e05f7 | ||
|
|
4942d0a629 | ||
|
|
873b7715f4 | ||
|
|
98e7ed6920 | ||
|
|
046f5e645e | ||
|
|
f5e5a7094c | ||
|
|
154125fee6 | ||
|
|
9f8e960ebe | ||
|
|
4179b0be0a | ||
|
|
28bafa38db | ||
|
|
b07552565e | ||
|
|
c4427471d2 | ||
|
|
08f81c6784 | ||
|
|
a471e98aca | ||
|
|
75a8fcc8a0 | ||
|
|
46ef76c168 | ||
|
|
66637446c9 | ||
|
|
21efeb888a | ||
|
|
a4ee8b5322 | ||
|
|
36519ac47e | ||
|
|
3f514fceca | ||
|
|
c2249fdfac | ||
|
|
c610719a44 | ||
|
|
36a6c2461a | ||
|
|
c29f22c39e | ||
|
|
30d3062944 | ||
|
|
69ba75abf4 | ||
|
|
e4d486fec5 | ||
|
|
f242144dcf | ||
|
|
02dee2d664 | ||
|
|
a3dd2c3069 | ||
|
|
a23425e8aa | ||
|
|
be79ddc9a3 | ||
|
|
7d71015e8c | ||
|
|
ad54549b51 | ||
|
|
6cf032a164 | ||
|
|
6390d796ac | ||
|
|
98b8411905 | ||
|
|
ddf1029afa | ||
|
|
1effbc5cc9 | ||
|
|
414b645e9f | ||
|
|
398c76f496 | ||
|
|
1bc456dd95 | ||
|
|
2e8421884e | ||
|
|
70d9b193ac | ||
|
|
b49c11004a | ||
|
|
34843eea90 | ||
|
|
2d6d7f31e8 | ||
|
|
7a24cbff1c | ||
|
|
1e7eb2cf1c | ||
|
|
361256e016 | ||
|
|
8838dbd003 | ||
|
|
13a95e1f2b | ||
|
|
1aaa451a3e | ||
|
|
cbba81e54d | ||
|
|
370868dfac | ||
|
|
77f692aae2 | ||
|
|
9318e205ea | ||
|
|
ebcc717c19 | ||
|
|
4c16b564ee | ||
|
|
e2283d1453 | ||
|
|
d891801c5a | ||
|
|
de75386944 | ||
|
|
82dc37de50 | ||
|
|
b6fa7f62dc | ||
|
|
f9e0a95c5e | ||
|
|
b2c6e12647 | ||
|
|
caffb83780 | ||
|
|
8882cb5479 | ||
|
|
75dace2dee | ||
|
|
ad6487d042 | ||
|
|
a91604e8ab | ||
|
|
c364f7c643 | ||
|
|
53435ba184 | ||
|
|
25f8d5519b | ||
|
|
2e4fef6c66 | ||
|
|
80b2b7dc00 | ||
|
|
8585cd8e21 | ||
|
|
9fa2a7eeea | ||
|
|
2d1f74228d | ||
|
|
3d6f7aa0e1 | ||
|
|
3dea60366a | ||
|
|
d4d9a1df4c | ||
|
|
7d6975fd31 | ||
|
|
08be52ed17 | ||
|
|
682a7700c2 | ||
|
|
9d87009216 | ||
|
|
ef86838f62 | ||
|
|
35468233f8 | ||
|
|
26e229867d | ||
|
|
3a1578b3c6 | ||
|
|
d5e3d2cbbc | ||
|
|
c095248176 | ||
|
|
44601c8954 | ||
|
|
135dbb8f07 | ||
|
|
c95682a0c7 | ||
|
|
d177b9f7fa | ||
|
|
9b57615d94 | ||
|
|
c03f3eacd1 | ||
|
|
a26e395932 | ||
|
|
0870b87c96 | ||
|
|
b52a44a7dd | ||
|
|
0a290aafef | ||
|
|
9014d4c410 | ||
|
|
60e58b4f5f | ||
|
|
620e74a6aa | ||
|
|
efa287ed35 | ||
|
|
a24eb9d9b0 | ||
|
|
bd3dab8aae | ||
|
|
4fe1ebaa5b | ||
|
|
c5e944744b | ||
|
|
0c396181f7 | ||
|
|
0034474219 | ||
|
|
8136ad8287 | ||
|
|
681940d466 | ||
|
|
16488506e8 | ||
|
|
122fccc041 | ||
|
|
9d0ad35403 | ||
|
|
f9ec97e026 | ||
|
|
95495a2647 | ||
|
|
e3310a605c | ||
|
|
b55719bf28 | ||
|
|
b957b51279 | ||
|
|
90bcfab369 | ||
|
|
f8a8e30641 | ||
|
|
25cb98e7a7 | ||
|
|
03e1bb7cf9 | ||
|
|
85dbb24f3a | ||
|
|
d817635782 | ||
|
|
2f4f237810 | ||
|
|
5ac94d810f | ||
|
|
39dc46dc25 | ||
|
|
0d9cf725f7 | ||
|
|
e55dbead5b | ||
|
|
7d046e5b30 | ||
|
|
8b4693cf66 | ||
|
|
a1172c9a82 | ||
|
|
1ed2bd33f0 | ||
|
|
4c159bd0ba | ||
|
|
050654b2a9 | ||
|
|
61b261e1b2 | ||
|
|
017b010206 | ||
|
|
00f5189f58 | ||
|
|
4a8309ed1f | ||
|
|
76cfc31a1d | ||
|
|
d9ec434699 | ||
|
|
239f3c40be | ||
|
|
09c8c6e670 | ||
|
|
7e4ad01c94 | ||
|
|
ed98e269ef | ||
|
|
b47d63334f | ||
|
|
5e2a3a5aea | ||
|
|
1a7eb21fc7 | ||
|
|
834a51cdc9 | ||
|
|
1b69d99c06 | ||
|
|
ad189933c6 | ||
|
|
9d86ff32de | ||
|
|
278bb57a58 | ||
|
|
0ba494e0ba | ||
|
|
8b247054bb | ||
|
|
7c5c8e4e0d | ||
|
|
ad106a27f3 | ||
|
|
9d6f61b49e | ||
|
|
02368954a0 | ||
|
|
b477a35a01 | ||
|
|
16622887de | ||
|
|
9059d1fb17 | ||
|
|
df2b008d82 | ||
|
|
0da871efd0 | ||
|
|
1c55349f81 | ||
|
|
9309fa1e81 | ||
|
|
5996189f91 | ||
|
|
bd2b984bfb | ||
|
|
194409a117 | ||
|
|
27978b216d | ||
|
|
c38fa77ce6 | ||
|
|
3eb49f7422 | ||
|
|
1989d615d2 | ||
|
|
239412d265 | ||
|
|
375a419a9e | ||
|
|
875c8ab424 | ||
|
|
c9bfc810ce | ||
|
|
46ecb16949 | ||
|
|
d6a785b645 | ||
|
|
79db828a01 |
@@ -18,3 +18,7 @@ ENV/
|
|||||||
README*.md
|
README*.md
|
||||||
dashboard/
|
dashboard/
|
||||||
data/
|
data/
|
||||||
|
changelogs/
|
||||||
|
tests/
|
||||||
|
.ruff_cache/
|
||||||
|
.astrbot
|
||||||
15
.github/FUNDING.yml
vendored
Normal file
15
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# These are supported funding model platforms
|
||||||
|
|
||||||
|
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||||
|
patreon: # Replace with a single Patreon username
|
||||||
|
open_collective: astrbot
|
||||||
|
ko_fi: # Replace with a single Ko-fi username
|
||||||
|
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||||
|
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||||
|
liberapay: # Replace with a single Liberapay username
|
||||||
|
issuehunt: # Replace with a single IssueHunt username
|
||||||
|
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||||
|
polar: # Replace with a single Polar username
|
||||||
|
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||||
|
thanks_dev: # Replace with a single thanks.dev username
|
||||||
|
custom: ['https://afdian.com/a/astrbot_team']
|
||||||
9
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
9
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -6,7 +6,7 @@ body:
|
|||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: |
|
value: |
|
||||||
欢迎发布插件到插件市场!
|
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
@@ -22,9 +22,10 @@ body:
|
|||||||
插件名:
|
插件名:
|
||||||
插件作者:
|
插件作者:
|
||||||
插件简介:
|
插件简介:
|
||||||
标签: (可选)
|
支持的消息平台:(必填,如 QQ、微信、飞书)
|
||||||
社交链接: (可选, 将会在插件市场作者名称上作为可点击的链接)
|
标签:(可选)
|
||||||
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。
|
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
|
||||||
|
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
|
||||||
|
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,5 +1,5 @@
|
|||||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||||
修复了 #XYZ
|
解决了 #XYZ
|
||||||
|
|
||||||
### Motivation
|
### Motivation
|
||||||
|
|
||||||
@@ -8,3 +8,12 @@
|
|||||||
### Modifications
|
### Modifications
|
||||||
|
|
||||||
<!--简单解释你的改动-->
|
<!--简单解释你的改动-->
|
||||||
|
|
||||||
|
### Check
|
||||||
|
|
||||||
|
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||||
|
|
||||||
|
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||||
|
- [ ] 👀 我的更改经过良好的测试
|
||||||
|
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||||
|
- [ ] 😮 我的更改没有引入恶意代码
|
||||||
|
|||||||
31
.github/workflows/auto_release.yml
vendored
31
.github/workflows/auto_release.yml
vendored
@@ -7,7 +7,7 @@ on:
|
|||||||
name: Auto Release
|
name: Auto Release
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build-and-publish-to-github-release:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
@@ -28,8 +28,35 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Create Release
|
- name: Create GitHub Release
|
||||||
uses: ncipollo/release-action@v1
|
uses: ncipollo/release-action@v1
|
||||||
with:
|
with:
|
||||||
bodyFile: ${{ env.changelog }}
|
bodyFile: ${{ env.changelog }}
|
||||||
artifacts: "dashboard/dist.zip"
|
artifacts: "dashboard/dist.zip"
|
||||||
|
|
||||||
|
build-and-publish-to-pypi:
|
||||||
|
# 构建并发布到 PyPI
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: build-and-publish-to-github-release
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
run: |
|
||||||
|
python -m pip install uv
|
||||||
|
|
||||||
|
- name: Build package
|
||||||
|
run: |
|
||||||
|
uv build
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||||
|
run: |
|
||||||
|
uv publish
|
||||||
|
|||||||
31
.github/workflows/dashboard_ci.yml
vendored
Normal file
31
.github/workflows/dashboard_ci.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
name: AstrBot Dashboard CI
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: npm install, build
|
||||||
|
run: |
|
||||||
|
cd dashboard
|
||||||
|
npm install
|
||||||
|
npm run build
|
||||||
|
|
||||||
|
- name: Inject Commit SHA
|
||||||
|
id: get_sha
|
||||||
|
run: |
|
||||||
|
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||||
|
mkdir -p dashboard/dist/assets
|
||||||
|
echo $COMMIT_SHA > dashboard/dist/assets/version
|
||||||
|
|
||||||
|
- name: Archive production artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: dist-without-markdown
|
||||||
|
path: |
|
||||||
|
dashboard/dist
|
||||||
|
!dist/**/*.md
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,6 +1,8 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
botpy.log
|
botpy.log
|
||||||
.vscode
|
.vscode
|
||||||
|
.venv*
|
||||||
|
.idea
|
||||||
data_v2.db
|
data_v2.db
|
||||||
data_v3.db
|
data_v3.db
|
||||||
configs/session
|
configs/session
|
||||||
@@ -26,3 +28,6 @@ venv/*
|
|||||||
packages/python_interpreter/workplace
|
packages/python_interpreter/workplace
|
||||||
.venv/*
|
.venv/*
|
||||||
.conda/
|
.conda/
|
||||||
|
.idea
|
||||||
|
pytest.ini
|
||||||
|
.astrbot
|
||||||
@@ -7,7 +7,7 @@ ci:
|
|||||||
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.9.9
|
rev: v0.11.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.10
|
||||||
17
Dockerfile
17
Dockerfile
@@ -4,19 +4,32 @@ WORKDIR /AstrBot
|
|||||||
COPY . /AstrBot/
|
COPY . /AstrBot/
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
nodejs \
|
||||||
|
npm \
|
||||||
gcc \
|
gcc \
|
||||||
build-essential \
|
build-essential \
|
||||||
python3-dev \
|
python3-dev \
|
||||||
libffi-dev \
|
libffi-dev \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
|
ca-certificates \
|
||||||
|
bash \
|
||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN python -m pip install -r requirements.txt --no-cache-dir
|
RUN python -m pip install uv
|
||||||
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
|
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||||
|
|
||||||
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
|
# 释出 ffmpeg
|
||||||
|
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||||
|
|
||||||
|
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||||
|
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||||
|
|
||||||
EXPOSE 6185
|
EXPOSE 6185
|
||||||
EXPOSE 6186
|
EXPOSE 6186
|
||||||
|
|
||||||
CMD [ "python", "main.py" ]
|
CMD [ "python", "main.py" ]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
35
Dockerfile_with_node
Normal file
35
Dockerfile_with_node
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
WORKDIR /AstrBot
|
||||||
|
|
||||||
|
COPY . /AstrBot/
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
gcc \
|
||||||
|
build-essential \
|
||||||
|
python3-dev \
|
||||||
|
libffi-dev \
|
||||||
|
libssl-dev \
|
||||||
|
curl \
|
||||||
|
unzip \
|
||||||
|
ca-certificates \
|
||||||
|
bash \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Installation of Node.js
|
||||||
|
ENV NVM_DIR="/root/.nvm"
|
||||||
|
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
|
||||||
|
. "$NVM_DIR/nvm.sh" && \
|
||||||
|
nvm install 22 && \
|
||||||
|
nvm use 22
|
||||||
|
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
|
||||||
|
|
||||||
|
RUN python -m pip install uv
|
||||||
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
|
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
|
||||||
|
|
||||||
|
EXPOSE 6185
|
||||||
|
EXPOSE 6186
|
||||||
|
|
||||||
|
CMD ["python", "main.py"]
|
||||||
145
README.md
145
README.md
@@ -1,6 +1,6 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -10,14 +10,14 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
|
|
||||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||

|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||
[](https://codecov.io/gh/Soulter/AstrBot)
|

|
||||||
[](https://gitcode.com/Soulter/AstrBot)
|

|
||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
@@ -27,19 +27,34 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
|
|
||||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||||
|
|
||||||
|
|
||||||
|
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||||
|
-->
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
>
|
||||||
|
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,`v3.5.10` 已经支持接入 WeChatPadPro 替换 gewechat 方式。详见文档 [WeChatPadPro](https://astrbot.app/deploy/platform/wechat/wechatpadpro.html)
|
||||||
|
|
||||||
|
## ✨ 近期更新
|
||||||
|
|
||||||
|
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||||
|
|
||||||
## ✨ 主要功能
|
## ✨ 主要功能
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||||
|
|
||||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||||
>
|
>
|
||||||
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
|
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||||
|
|
||||||
## ✨ 使用方式
|
## ✨ 使用方式
|
||||||
|
|
||||||
@@ -49,30 +64,48 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
|
|
||||||
#### Windows 一键安装器部署
|
#### Windows 一键安装器部署
|
||||||
|
|
||||||
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||||
|
|
||||||
#### Replit 部署
|
#### 宝塔面板部署
|
||||||
|
|
||||||
[](https://repl.it/github/Soulter/AstrBot)
|
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||||
|
|
||||||
#### CasaOS 部署
|
#### CasaOS 部署
|
||||||
|
|
||||||
社区贡献的部署方式。
|
社区贡献的部署方式。
|
||||||
|
|
||||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
||||||
|
|
||||||
#### 手动部署
|
#### 手动部署
|
||||||
|
|
||||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
> 推荐使用 `uv`。
|
||||||
|
|
||||||
## 🚀 路线图
|
首先,安装 uv:
|
||||||
|
|
||||||
### 垂类功能
|
```bash
|
||||||
|
pip install uv
|
||||||
|
```
|
||||||
|
|
||||||
1. 更好的上下文管理:限制 token 总数、对话上下文总结
|
通过 Git Clone 安装 AstrBot:
|
||||||
3. AstrBot in Minecraft
|
|
||||||
|
|
||||||
### 横功能
|
```bash
|
||||||
|
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||||
|
uv run main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或者,直接通过 uvx 安装 AstrBot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir astrbot && cd astrbot
|
||||||
|
uvx astrbot init
|
||||||
|
# uvx astrbot run
|
||||||
|
```
|
||||||
|
|
||||||
|
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||||
|
|
||||||
|
#### Replit 部署
|
||||||
|
|
||||||
|
[](https://repl.it/github/Soulter/AstrBot)
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
@@ -80,10 +113,12 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
| -------- | ------- | ------- | ------ |
|
| -------- | ------- | ------- | ------ |
|
||||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
| 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
| Telegram | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
| 企业微信 | ✔ | 私聊 | 文字、图片、语音 |
|
||||||
| 飞书 | ✔ | 群聊 | 文字、图片 |
|
| 微信客服 | ✔ | 私聊 | 文字、图片 |
|
||||||
|
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
|
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||||
| Discord | 🚧 | 计划内 | - |
|
| Discord | 🚧 | 计划内 | - |
|
||||||
| WhatsApp | 🚧 | 计划内 | - |
|
| WhatsApp | 🚧 | 计划内 | - |
|
||||||
@@ -93,20 +128,26 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
|||||||
|
|
||||||
| 名称 | 支持性 | 类型 | 备注 |
|
| 名称 | 支持性 | 类型 | 备注 |
|
||||||
| -------- | ------- | ------- | ------- |
|
| -------- | ------- | ------- | ------- |
|
||||||
| OpenAI API | ✔ | 文本生成 | 同时也支持 DeepSeek、Google Gemini、GLM(智谱)、Moonshot(月之暗面)、阿里云百炼、硅基流动、xAI 等所有兼容 OpenAI API 的服务 |
|
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||||
| Claude API | ✔ | 文本生成 | |
|
| Claude API | ✔ | 文本生成 | |
|
||||||
| Google Gemini API | ✔ | 文本生成 | |
|
| Google Gemini API | ✔ | 文本生成 | |
|
||||||
| Dify | ✔ | LLMOps | |
|
| Dify | ✔ | LLMOps | |
|
||||||
| DashScope(阿里云百炼应用) | ✔ | LLMOps | |
|
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||||
|
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||||
|
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||||
| OneAPI | ✔ | LLM 分发系统 | |
|
| OneAPI | ✔ | LLM 分发系统 | |
|
||||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||||
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||||
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||||
|
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||||
|
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||||
|
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||||
|
|
||||||
|
|
||||||
## ❤️ 贡献
|
## ❤️ 贡献
|
||||||
|
|
||||||
@@ -134,38 +175,45 @@ pre-commit install
|
|||||||
|
|
||||||
## ✨ Demo
|
## ✨ Demo
|
||||||
|
|
||||||
> [!NOTE]
|
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
|
||||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
|
||||||
|
|
||||||
<div align='center'>
|
<div align='center'>
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||||
|
|
||||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||||
|
|
||||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
|
||||||
|
|
||||||
_✨ 自然语言待办事项 ✨_
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||||
|
|
||||||
_✨ 插件系统——部分插件展示 ✨_
|
_✨ 插件系统——部分插件展示 ✨_
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
||||||
|
|
||||||
_✨ 管理面板 ✨_
|
_✨ WebUI ✨_
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
_✨ 内置 Web Chat,在线与机器人交互 ✨_
|
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
## ❤️ Special Thanks
|
||||||
|
|
||||||
|
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||||
|
|
||||||
|
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||||
|
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||||
|
</a>
|
||||||
|
|
||||||
|
此外,本项目的诞生离不开以下开源项目:
|
||||||
|
|
||||||
|
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
|
||||||
|
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
@@ -183,16 +231,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
|||||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||||
3. Please ensure compliance with local laws and regulations when using this project.
|
3. Please ensure compliance with local laws and regulations when using this project.
|
||||||
|
|
||||||
<!-- ## ✨ ATRI [Beta 测试]
|
|
||||||
|
|
||||||
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
|
||||||
|
|
||||||
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
|
|
||||||
2. 长期记忆
|
|
||||||
3. 表情包理解与回复
|
|
||||||
4. TTS
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ AstrBot is a loosely coupled, asynchronous chatbot and development framework tha
|
|||||||
|
|
||||||
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
||||||
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
||||||
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://astrbot.app/others/dify.html) for easy access to Dify assistants/knowledge bases/workflows.
|
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
|
||||||
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
||||||
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
||||||
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
|||||||
|
|
||||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||||
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from astrbot.core.platform import (
|
|||||||
MessageMember,
|
MessageMember,
|
||||||
MessageType,
|
MessageType,
|
||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
|
Group,
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.platform.register import register_platform_adapter
|
from astrbot.core.platform.register import register_platform_adapter
|
||||||
@@ -18,4 +19,5 @@ __all__ = [
|
|||||||
"MessageType",
|
"MessageType",
|
||||||
"PlatformMetadata",
|
"PlatformMetadata",
|
||||||
"register_platform_adapter",
|
"register_platform_adapter",
|
||||||
|
"Group",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||||
from astrbot.core.provider.entites import (
|
from astrbot.core.provider.entities import (
|
||||||
ProviderRequest,
|
ProviderRequest,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
ProviderMetaData,
|
ProviderMetaData,
|
||||||
|
|||||||
@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
|
|||||||
register_star as register, # 注册插件(Star)
|
register_star as register, # 注册插件(Star)
|
||||||
)
|
)
|
||||||
|
|
||||||
from astrbot.core.star import Context, Star
|
from astrbot.core.star import Context, Star, StarTools
|
||||||
from astrbot.core.star.config import *
|
from astrbot.core.star.config import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["register", "Context", "Star", "StarTools"]
|
||||||
"register",
|
|
||||||
"Context",
|
|
||||||
"Star",
|
|
||||||
]
|
|
||||||
|
|||||||
1
astrbot/cli/__init__.py
Normal file
1
astrbot/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__version__ = "3.5.8"
|
||||||
59
astrbot/cli/__main__.py
Normal file
59
astrbot/cli/__main__.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
AstrBot CLI入口
|
||||||
|
"""
|
||||||
|
|
||||||
|
import click
|
||||||
|
import sys
|
||||||
|
from . import __version__
|
||||||
|
from .commands import init, run, plug, conf
|
||||||
|
|
||||||
|
logo_tmpl = r"""
|
||||||
|
___ _______.___________..______ .______ ______ .___________.
|
||||||
|
/ \ / | || _ \ | _ \ / __ \ | |
|
||||||
|
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||||
|
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||||
|
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||||
|
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
@click.version_option(__version__, prog_name="AstrBot")
|
||||||
|
def cli() -> None:
|
||||||
|
"""The AstrBot CLI"""
|
||||||
|
click.echo(logo_tmpl)
|
||||||
|
click.echo("Welcome to AstrBot CLI!")
|
||||||
|
click.echo(f"AstrBot CLI version: {__version__}")
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.argument("command_name", required=False, type=str)
|
||||||
|
def help(command_name: str | None) -> None:
|
||||||
|
"""显示命令的帮助信息
|
||||||
|
|
||||||
|
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
|
||||||
|
否则,显示通用帮助信息。
|
||||||
|
"""
|
||||||
|
ctx = click.get_current_context()
|
||||||
|
if command_name:
|
||||||
|
# 查找指定命令
|
||||||
|
command = cli.get_command(ctx, command_name)
|
||||||
|
if command:
|
||||||
|
# 显示特定命令的帮助信息
|
||||||
|
click.echo(command.get_help(ctx))
|
||||||
|
else:
|
||||||
|
click.echo(f"Unknown command: {command_name}")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
# 显示通用帮助信息
|
||||||
|
click.echo(cli.get_help(ctx))
|
||||||
|
|
||||||
|
|
||||||
|
cli.add_command(init)
|
||||||
|
cli.add_command(run)
|
||||||
|
cli.add_command(help)
|
||||||
|
cli.add_command(plug)
|
||||||
|
cli.add_command(conf)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
6
astrbot/cli/commands/__init__.py
Normal file
6
astrbot/cli/commands/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .cmd_init import init
|
||||||
|
from .cmd_run import run
|
||||||
|
from .cmd_plug import plug
|
||||||
|
from .cmd_conf import conf
|
||||||
|
|
||||||
|
__all__ = ["init", "run", "plug", "conf"]
|
||||||
206
astrbot/cli/commands/cmd_conf.py
Normal file
206
astrbot/cli/commands/cmd_conf.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
import json
|
||||||
|
import click
|
||||||
|
import hashlib
|
||||||
|
import zoneinfo
|
||||||
|
from typing import Any, Callable
|
||||||
|
from ..utils import get_astrbot_root, check_astrbot_root
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_log_level(value: str) -> str:
|
||||||
|
"""验证日志级别"""
|
||||||
|
value = value.upper()
|
||||||
|
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||||
|
raise click.ClickException(
|
||||||
|
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_dashboard_port(value: str) -> int:
|
||||||
|
"""验证 Dashboard 端口"""
|
||||||
|
try:
|
||||||
|
port = int(value)
|
||||||
|
if port < 1 or port > 65535:
|
||||||
|
raise click.ClickException("端口必须在 1-65535 范围内")
|
||||||
|
return port
|
||||||
|
except ValueError:
|
||||||
|
raise click.ClickException("端口必须是数字")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_dashboard_username(value: str) -> str:
|
||||||
|
"""验证 Dashboard 用户名"""
|
||||||
|
if not value:
|
||||||
|
raise click.ClickException("用户名不能为空")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_dashboard_password(value: str) -> str:
|
||||||
|
"""验证 Dashboard 密码"""
|
||||||
|
if not value:
|
||||||
|
raise click.ClickException("密码不能为空")
|
||||||
|
return hashlib.md5(value.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_timezone(value: str) -> str:
|
||||||
|
"""验证时区"""
|
||||||
|
try:
|
||||||
|
zoneinfo.ZoneInfo(value)
|
||||||
|
except Exception:
|
||||||
|
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_callback_api_base(value: str) -> str:
|
||||||
|
"""验证回调接口基址"""
|
||||||
|
if not value.startswith("http://") and not value.startswith("https://"):
|
||||||
|
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# 可通过CLI设置的配置项,配置键到验证器函数的映射
|
||||||
|
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||||
|
"timezone": _validate_timezone,
|
||||||
|
"log_level": _validate_log_level,
|
||||||
|
"dashboard.port": _validate_dashboard_port,
|
||||||
|
"dashboard.username": _validate_dashboard_username,
|
||||||
|
"dashboard.password": _validate_dashboard_password,
|
||||||
|
"callback_api_base": _validate_callback_api_base,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config() -> dict[str, Any]:
|
||||||
|
"""加载或初始化配置文件"""
|
||||||
|
root = get_astrbot_root()
|
||||||
|
if not check_astrbot_root(root):
|
||||||
|
raise click.ClickException(
|
||||||
|
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||||
|
)
|
||||||
|
|
||||||
|
config_path = root / "data" / "cmd_config.json"
|
||||||
|
if not config_path.exists():
|
||||||
|
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||||
|
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
|
||||||
|
encoding="utf-8-sig",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(config_path.read_text(encoding="utf-8-sig"))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise click.ClickException(f"配置文件解析失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _save_config(config: dict[str, Any]) -> None:
|
||||||
|
"""保存配置文件"""
|
||||||
|
config_path = get_astrbot_root() / "data" / "cmd_config.json"
|
||||||
|
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||||
|
"""设置嵌套字典中的值"""
|
||||||
|
parts = path.split(".")
|
||||||
|
for part in parts[:-1]:
|
||||||
|
if part not in obj:
|
||||||
|
obj[part] = {}
|
||||||
|
elif not isinstance(obj[part], dict):
|
||||||
|
raise click.ClickException(
|
||||||
|
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
|
||||||
|
)
|
||||||
|
obj = obj[part]
|
||||||
|
obj[parts[-1]] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||||
|
"""获取嵌套字典中的值"""
|
||||||
|
parts = path.split(".")
|
||||||
|
for part in parts:
|
||||||
|
obj = obj[part]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
@click.group(name="conf")
|
||||||
|
def conf():
|
||||||
|
"""配置管理命令
|
||||||
|
|
||||||
|
支持的配置项:
|
||||||
|
|
||||||
|
- timezone: 时区设置 (例如: Asia/Shanghai)
|
||||||
|
|
||||||
|
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||||
|
|
||||||
|
- dashboard.port: Dashboard 端口
|
||||||
|
|
||||||
|
- dashboard.username: Dashboard 用户名
|
||||||
|
|
||||||
|
- dashboard.password: Dashboard 密码
|
||||||
|
|
||||||
|
- callback_api_base: 回调接口基址
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@conf.command(name="set")
|
||||||
|
@click.argument("key")
|
||||||
|
@click.argument("value")
|
||||||
|
def set_config(key: str, value: str):
|
||||||
|
"""设置配置项的值"""
|
||||||
|
if key not in CONFIG_VALIDATORS.keys():
|
||||||
|
raise click.ClickException(f"不支持的配置项: {key}")
|
||||||
|
|
||||||
|
config = _load_config()
|
||||||
|
|
||||||
|
try:
|
||||||
|
old_value = _get_nested_item(config, key)
|
||||||
|
validated_value = CONFIG_VALIDATORS[key](value)
|
||||||
|
_set_nested_item(config, key, validated_value)
|
||||||
|
_save_config(config)
|
||||||
|
|
||||||
|
click.echo(f"配置已更新: {key}")
|
||||||
|
if key == "dashboard.password":
|
||||||
|
click.echo(" 原值: ********")
|
||||||
|
click.echo(" 新值: ********")
|
||||||
|
else:
|
||||||
|
click.echo(f" 原值: {old_value}")
|
||||||
|
click.echo(f" 新值: {validated_value}")
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
raise click.ClickException(f"未知的配置项: {key}")
|
||||||
|
except Exception as e:
|
||||||
|
raise click.UsageError(f"设置配置失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@conf.command(name="get")
|
||||||
|
@click.argument("key", required=False)
|
||||||
|
def get_config(key: str = None):
|
||||||
|
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||||
|
config = _load_config()
|
||||||
|
|
||||||
|
if key:
|
||||||
|
if key not in CONFIG_VALIDATORS.keys():
|
||||||
|
raise click.ClickException(f"不支持的配置项: {key}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = _get_nested_item(config, key)
|
||||||
|
if key == "dashboard.password":
|
||||||
|
value = "********"
|
||||||
|
click.echo(f"{key}: {value}")
|
||||||
|
except KeyError:
|
||||||
|
raise click.ClickException(f"未知的配置项: {key}")
|
||||||
|
except Exception as e:
|
||||||
|
raise click.UsageError(f"获取配置失败: {str(e)}")
|
||||||
|
else:
|
||||||
|
click.echo("当前配置:")
|
||||||
|
for key in CONFIG_VALIDATORS.keys():
|
||||||
|
try:
|
||||||
|
value = (
|
||||||
|
"********"
|
||||||
|
if key == "dashboard.password"
|
||||||
|
else _get_nested_item(config, key)
|
||||||
|
)
|
||||||
|
click.echo(f" {key}: {value}")
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
pass
|
||||||
55
astrbot/cli/commands/cmd_init.py
Normal file
55
astrbot/cli/commands/cmd_init.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import click
|
||||||
|
from filelock import FileLock, Timeout
|
||||||
|
|
||||||
|
from ..utils import check_dashboard, get_astrbot_root
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_astrbot(astrbot_root) -> None:
|
||||||
|
"""执行 AstrBot 初始化逻辑"""
|
||||||
|
dot_astrbot = astrbot_root / ".astrbot"
|
||||||
|
|
||||||
|
if not dot_astrbot.exists():
|
||||||
|
click.echo(f"Current Directory: {astrbot_root}")
|
||||||
|
click.echo(
|
||||||
|
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||||
|
)
|
||||||
|
if click.confirm(
|
||||||
|
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||||
|
default=True,
|
||||||
|
abort=True,
|
||||||
|
):
|
||||||
|
dot_astrbot.touch()
|
||||||
|
click.echo(f"Created {dot_astrbot}")
|
||||||
|
|
||||||
|
paths = {
|
||||||
|
"data": astrbot_root / "data",
|
||||||
|
"config": astrbot_root / "data" / "config",
|
||||||
|
"plugins": astrbot_root / "data" / "plugins",
|
||||||
|
"temp": astrbot_root / "data" / "temp",
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, path in paths.items():
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
|
||||||
|
|
||||||
|
await check_dashboard(astrbot_root / "data")
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
def init() -> None:
|
||||||
|
"""初始化 AstrBot"""
|
||||||
|
click.echo("Initializing AstrBot...")
|
||||||
|
astrbot_root = get_astrbot_root()
|
||||||
|
lock_file = astrbot_root / "astrbot.lock"
|
||||||
|
lock = FileLock(lock_file, timeout=5)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with lock.acquire():
|
||||||
|
asyncio.run(initialize_astrbot(astrbot_root))
|
||||||
|
except Timeout:
|
||||||
|
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise click.ClickException(f"初始化失败: {e!s}")
|
||||||
247
astrbot/cli/commands/cmd_plug.py
Normal file
247
astrbot/cli/commands/cmd_plug.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import click
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
get_git_repo,
|
||||||
|
build_plug_list,
|
||||||
|
manage_plugin,
|
||||||
|
PluginStatus,
|
||||||
|
check_astrbot_root,
|
||||||
|
get_astrbot_root,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def plug():
|
||||||
|
"""插件管理"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data_path() -> Path:
|
||||||
|
base = get_astrbot_root()
|
||||||
|
if not check_astrbot_root(base):
|
||||||
|
raise click.ClickException(
|
||||||
|
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||||
|
)
|
||||||
|
return (base / "data").resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def display_plugins(plugins, title=None, color=None):
|
||||||
|
if title:
|
||||||
|
click.echo(click.style(title, fg=color, bold=True))
|
||||||
|
|
||||||
|
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
|
||||||
|
click.echo("-" * 85)
|
||||||
|
|
||||||
|
for p in plugins:
|
||||||
|
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
|
||||||
|
click.echo(
|
||||||
|
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
|
||||||
|
f"{p['author']:<15} {desc:<30}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@plug.command()
|
||||||
|
@click.argument("name")
|
||||||
|
def new(name: str):
|
||||||
|
"""创建新插件"""
|
||||||
|
base_path = _get_data_path()
|
||||||
|
plug_path = base_path / "plugins" / name
|
||||||
|
|
||||||
|
if plug_path.exists():
|
||||||
|
raise click.ClickException(f"插件 {name} 已存在")
|
||||||
|
|
||||||
|
author = click.prompt("请输入插件作者", type=str)
|
||||||
|
desc = click.prompt("请输入插件描述", type=str)
|
||||||
|
version = click.prompt("请输入插件版本", type=str)
|
||||||
|
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
|
||||||
|
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
|
||||||
|
repo = click.prompt("请输入插件仓库:", type=str)
|
||||||
|
if not repo.startswith("http"):
|
||||||
|
raise click.ClickException("仓库地址必须以 http 开头")
|
||||||
|
|
||||||
|
click.echo("下载插件模板...")
|
||||||
|
get_git_repo(
|
||||||
|
"https://github.com/Soulter/helloworld",
|
||||||
|
plug_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
click.echo("重写插件信息...")
|
||||||
|
# 重写 metadata.yaml
|
||||||
|
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
f"name: {name}\n"
|
||||||
|
f"desc: {desc}\n"
|
||||||
|
f"version: {version}\n"
|
||||||
|
f"author: {author}\n"
|
||||||
|
f"repo: {repo}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重写 README.md
|
||||||
|
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
|
||||||
|
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
|
||||||
|
|
||||||
|
# 重写 main.py
|
||||||
|
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
new_content = content.replace(
|
||||||
|
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
|
||||||
|
f'@register("{name}", "{author}", "{desc}", "{version}")',
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
|
||||||
|
f.write(new_content)
|
||||||
|
|
||||||
|
click.echo(f"插件 {name} 创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@plug.command()
|
||||||
|
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||||
|
def list(all: bool):
|
||||||
|
"""列出插件"""
|
||||||
|
base_path = _get_data_path()
|
||||||
|
plugins = build_plug_list(base_path / "plugins")
|
||||||
|
|
||||||
|
# 未发布的插件
|
||||||
|
not_published_plugins = [
|
||||||
|
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
|
||||||
|
]
|
||||||
|
if not_published_plugins:
|
||||||
|
display_plugins(not_published_plugins, "未发布的插件", "red")
|
||||||
|
|
||||||
|
# 需要更新的插件
|
||||||
|
need_update_plugins = [
|
||||||
|
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||||
|
]
|
||||||
|
if need_update_plugins:
|
||||||
|
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
|
||||||
|
|
||||||
|
# 已安装的插件
|
||||||
|
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
|
||||||
|
if installed_plugins:
|
||||||
|
display_plugins(installed_plugins, "已安装的插件", "green")
|
||||||
|
|
||||||
|
# 未安装的插件
|
||||||
|
not_installed_plugins = [
|
||||||
|
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
|
||||||
|
]
|
||||||
|
if not_installed_plugins and all:
|
||||||
|
display_plugins(not_installed_plugins, "未安装的插件", "blue")
|
||||||
|
|
||||||
|
if (
|
||||||
|
not any([not_published_plugins, need_update_plugins, installed_plugins])
|
||||||
|
and not all
|
||||||
|
):
|
||||||
|
click.echo("未安装任何插件")
|
||||||
|
|
||||||
|
|
||||||
|
@plug.command()
|
||||||
|
@click.argument("name")
|
||||||
|
@click.option("--proxy", help="代理服务器地址")
|
||||||
|
def install(name: str, proxy: str | None):
|
||||||
|
"""安装插件"""
|
||||||
|
base_path = _get_data_path()
|
||||||
|
plug_path = base_path / "plugins"
|
||||||
|
plugins = build_plug_list(base_path / "plugins")
|
||||||
|
|
||||||
|
plugin = next(
|
||||||
|
(
|
||||||
|
p
|
||||||
|
for p in plugins
|
||||||
|
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not plugin:
|
||||||
|
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
|
||||||
|
|
||||||
|
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||||
|
|
||||||
|
|
||||||
|
@plug.command()
|
||||||
|
@click.argument("name")
|
||||||
|
def remove(name: str):
|
||||||
|
"""卸载插件"""
|
||||||
|
base_path = _get_data_path()
|
||||||
|
plugins = build_plug_list(base_path / "plugins")
|
||||||
|
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||||
|
|
||||||
|
if not plugin or not plugin.get("local_path"):
|
||||||
|
raise click.ClickException(f"插件 {name} 不存在或未安装")
|
||||||
|
|
||||||
|
plugin_path = plugin["local_path"]
|
||||||
|
|
||||||
|
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.rmtree(plugin_path)
|
||||||
|
click.echo(f"插件 {name} 已卸载")
|
||||||
|
except Exception as e:
|
||||||
|
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@plug.command()
|
||||||
|
@click.argument("name", required=False)
|
||||||
|
@click.option("--proxy", help="Github代理地址")
|
||||||
|
def update(name: str, proxy: str | None):
|
||||||
|
"""更新插件"""
|
||||||
|
base_path = _get_data_path()
|
||||||
|
plug_path = base_path / "plugins"
|
||||||
|
plugins = build_plug_list(base_path / "plugins")
|
||||||
|
|
||||||
|
if name:
|
||||||
|
plugin = next(
|
||||||
|
(
|
||||||
|
p
|
||||||
|
for p in plugins
|
||||||
|
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not plugin:
|
||||||
|
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
|
||||||
|
|
||||||
|
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||||
|
else:
|
||||||
|
need_update_plugins = [
|
||||||
|
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||||
|
]
|
||||||
|
|
||||||
|
if not need_update_plugins:
|
||||||
|
click.echo("没有需要更新的插件")
|
||||||
|
return
|
||||||
|
|
||||||
|
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
|
||||||
|
for plugin in need_update_plugins:
|
||||||
|
plugin_name = plugin["name"]
|
||||||
|
click.echo(f"正在更新插件 {plugin_name}...")
|
||||||
|
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||||
|
|
||||||
|
|
||||||
|
@plug.command()
|
||||||
|
@click.argument("query")
|
||||||
|
def search(query: str):
|
||||||
|
"""搜索插件"""
|
||||||
|
base_path = _get_data_path()
|
||||||
|
plugins = build_plug_list(base_path / "plugins")
|
||||||
|
|
||||||
|
matched_plugins = [
|
||||||
|
p
|
||||||
|
for p in plugins
|
||||||
|
if query.lower() in p["name"].lower()
|
||||||
|
or query.lower() in p["desc"].lower()
|
||||||
|
or query.lower() in p["author"].lower()
|
||||||
|
]
|
||||||
|
|
||||||
|
if not matched_plugins:
|
||||||
|
click.echo(f"未找到匹配 '{query}' 的插件")
|
||||||
|
return
|
||||||
|
|
||||||
|
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
|
||||||
63
astrbot/cli/commands/cmd_run.py
Normal file
63
astrbot/cli/commands/cmd_run.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import click
|
||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from filelock import FileLock, Timeout
|
||||||
|
|
||||||
|
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
|
||||||
|
|
||||||
|
|
||||||
|
async def run_astrbot(astrbot_root: Path):
|
||||||
|
"""运行 AstrBot"""
|
||||||
|
from astrbot.core import logger, LogManager, LogBroker, db_helper
|
||||||
|
from astrbot.core.initial_loader import InitialLoader
|
||||||
|
|
||||||
|
await check_dashboard(astrbot_root / "data")
|
||||||
|
|
||||||
|
log_broker = LogBroker()
|
||||||
|
LogManager.set_queue_handler(logger, log_broker)
|
||||||
|
db = db_helper
|
||||||
|
|
||||||
|
core_lifecycle = InitialLoader(db, log_broker)
|
||||||
|
|
||||||
|
await core_lifecycle.start()
|
||||||
|
|
||||||
|
|
||||||
|
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
|
||||||
|
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
|
||||||
|
@click.command()
|
||||||
|
def run(reload: bool, port: str) -> None:
|
||||||
|
"""运行 AstrBot"""
|
||||||
|
try:
|
||||||
|
os.environ["ASTRBOT_CLI"] = "1"
|
||||||
|
astrbot_root = get_astrbot_root()
|
||||||
|
|
||||||
|
if not check_astrbot_root(astrbot_root):
|
||||||
|
raise click.ClickException(
|
||||||
|
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||||
|
sys.path.insert(0, str(astrbot_root))
|
||||||
|
|
||||||
|
if port:
|
||||||
|
os.environ["DASHBOARD_PORT"] = port
|
||||||
|
|
||||||
|
if reload:
|
||||||
|
click.echo("启用插件自动重载")
|
||||||
|
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||||
|
|
||||||
|
lock_file = astrbot_root / "astrbot.lock"
|
||||||
|
lock = FileLock(lock_file, timeout=5)
|
||||||
|
with lock.acquire():
|
||||||
|
asyncio.run(run_astrbot(astrbot_root))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
click.echo("AstrBot 已关闭...")
|
||||||
|
except Timeout:
|
||||||
|
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||||
|
except Exception as e:
|
||||||
|
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
|
||||||
18
astrbot/cli/utils/__init__.py
Normal file
18
astrbot/cli/utils/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from .basic import (
|
||||||
|
get_astrbot_root,
|
||||||
|
check_astrbot_root,
|
||||||
|
check_dashboard,
|
||||||
|
)
|
||||||
|
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
|
||||||
|
from .version_comparator import VersionComparator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_astrbot_root",
|
||||||
|
"check_astrbot_root",
|
||||||
|
"check_dashboard",
|
||||||
|
"get_git_repo",
|
||||||
|
"manage_plugin",
|
||||||
|
"build_plug_list",
|
||||||
|
"VersionComparator",
|
||||||
|
"PluginStatus",
|
||||||
|
]
|
||||||
67
astrbot/cli/utils/basic.py
Normal file
67
astrbot/cli/utils/basic.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
|
||||||
|
def check_astrbot_root(path: str | Path) -> bool:
|
||||||
|
"""检查路径是否为 AstrBot 根目录"""
|
||||||
|
if not isinstance(path, Path):
|
||||||
|
path = Path(path)
|
||||||
|
if not path.exists() or not path.is_dir():
|
||||||
|
return False
|
||||||
|
if not (path / ".astrbot").exists():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_astrbot_root() -> Path:
|
||||||
|
"""获取Astrbot根目录路径"""
|
||||||
|
return Path.cwd()
|
||||||
|
|
||||||
|
|
||||||
|
async def check_dashboard(astrbot_root: Path) -> None:
|
||||||
|
"""检查是否安装了dashboard"""
|
||||||
|
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||||
|
from astrbot.core.config.default import VERSION
|
||||||
|
from .version_comparator import VersionComparator
|
||||||
|
|
||||||
|
try:
|
||||||
|
dashboard_version = await get_dashboard_version()
|
||||||
|
match dashboard_version:
|
||||||
|
case None:
|
||||||
|
click.echo("未安装管理面板")
|
||||||
|
if click.confirm(
|
||||||
|
"是否安装管理面板?",
|
||||||
|
default=True,
|
||||||
|
abort=True,
|
||||||
|
):
|
||||||
|
click.echo("正在安装管理面板...")
|
||||||
|
await download_dashboard(
|
||||||
|
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||||
|
)
|
||||||
|
click.echo("管理面板安装完成")
|
||||||
|
|
||||||
|
case str():
|
||||||
|
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||||
|
click.echo("管理面板已是最新版本")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
version = dashboard_version.split("v")[1]
|
||||||
|
click.echo(f"管理面板版本: {version}")
|
||||||
|
await download_dashboard(
|
||||||
|
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"下载管理面板失败: {e}")
|
||||||
|
return
|
||||||
|
except FileNotFoundError:
|
||||||
|
click.echo("初始化管理面板目录...")
|
||||||
|
try:
|
||||||
|
await download_dashboard(
|
||||||
|
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||||
|
)
|
||||||
|
click.echo("管理面板初始化完成")
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"下载管理面板失败: {e}")
|
||||||
|
return
|
||||||
230
astrbot/cli/utils/plugin.py
Normal file
230
astrbot/cli/utils/plugin.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import yaml
|
||||||
|
from enum import Enum
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import click
|
||||||
|
from .version_comparator import VersionComparator
|
||||||
|
|
||||||
|
|
||||||
|
class PluginStatus(str, Enum):
|
||||||
|
INSTALLED = "已安装"
|
||||||
|
NEED_UPDATE = "需更新"
|
||||||
|
NOT_INSTALLED = "未安装"
|
||||||
|
NOT_PUBLISHED = "未发布"
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||||
|
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||||
|
temp_dir = Path(tempfile.mkdtemp())
|
||||||
|
try:
|
||||||
|
# 解析仓库信息
|
||||||
|
repo_namespace = url.split("/")[-2:]
|
||||||
|
author = repo_namespace[0]
|
||||||
|
repo = repo_namespace[1]
|
||||||
|
|
||||||
|
# 尝试获取最新的 release
|
||||||
|
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||||
|
try:
|
||||||
|
with httpx.Client(
|
||||||
|
proxy=proxy if proxy else None, follow_redirects=True
|
||||||
|
) as client:
|
||||||
|
resp = client.get(release_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
releases = resp.json()
|
||||||
|
|
||||||
|
if releases:
|
||||||
|
# 使用最新的 release
|
||||||
|
download_url = releases[0]["zipball_url"]
|
||||||
|
else:
|
||||||
|
# 没有 release,使用默认分支
|
||||||
|
click.echo(f"正在从默认分支下载 {author}/{repo}")
|
||||||
|
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
|
||||||
|
download_url = url
|
||||||
|
|
||||||
|
# 应用代理
|
||||||
|
if proxy:
|
||||||
|
download_url = f"{proxy}/{download_url}"
|
||||||
|
|
||||||
|
# 下载并解压
|
||||||
|
with httpx.Client(
|
||||||
|
proxy=proxy if proxy else None, follow_redirects=True
|
||||||
|
) as client:
|
||||||
|
resp = client.get(download_url)
|
||||||
|
if (
|
||||||
|
resp.status_code == 404
|
||||||
|
and "archive/refs/heads/master.zip" in download_url
|
||||||
|
):
|
||||||
|
alt_url = download_url.replace("master.zip", "main.zip")
|
||||||
|
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||||
|
resp = client.get(alt_url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
else:
|
||||||
|
resp.raise_for_status()
|
||||||
|
zip_content = BytesIO(resp.content)
|
||||||
|
with ZipFile(zip_content) as z:
|
||||||
|
z.extractall(temp_dir)
|
||||||
|
namelist = z.namelist()
|
||||||
|
root_dir = Path(namelist[0]).parts[0] if namelist else ""
|
||||||
|
if target_path.exists():
|
||||||
|
shutil.rmtree(target_path)
|
||||||
|
shutil.move(temp_dir / root_dir, target_path)
|
||||||
|
finally:
|
||||||
|
if temp_dir.exists():
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||||
|
"""从 metadata.yaml 文件加载插件元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_dir: 插件目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含元数据的字典,如果读取失败则返回空字典
|
||||||
|
"""
|
||||||
|
yaml_path = plugin_dir / "metadata.yaml"
|
||||||
|
if yaml_path.exists():
|
||||||
|
try:
|
||||||
|
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def build_plug_list(plugins_dir: Path) -> list:
|
||||||
|
"""构建插件列表,包含本地和在线插件信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugins_dir (Path): 插件目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 包含插件信息的字典列表
|
||||||
|
"""
|
||||||
|
# 获取本地插件信息
|
||||||
|
result = []
|
||||||
|
if plugins_dir.exists():
|
||||||
|
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||||
|
plugin_dir = plugins_dir / plugin_name
|
||||||
|
|
||||||
|
# 从 metadata.yaml 加载元数据
|
||||||
|
metadata = load_yaml_metadata(plugin_dir)
|
||||||
|
|
||||||
|
# 如果成功加载元数据,添加到结果列表
|
||||||
|
if metadata and all(
|
||||||
|
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||||
|
):
|
||||||
|
result.append({
|
||||||
|
"name": str(metadata.get("name", "")),
|
||||||
|
"desc": str(metadata.get("desc", "")),
|
||||||
|
"version": str(metadata.get("version", "")),
|
||||||
|
"author": str(metadata.get("author", "")),
|
||||||
|
"repo": str(metadata.get("repo", "")),
|
||||||
|
"status": PluginStatus.INSTALLED,
|
||||||
|
"local_path": str(plugin_dir),
|
||||||
|
})
|
||||||
|
|
||||||
|
# 获取在线插件列表
|
||||||
|
online_plugins = []
|
||||||
|
try:
|
||||||
|
with httpx.Client() as client:
|
||||||
|
resp = client.get("https://api.soulter.top/astrbot/plugins")
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
for plugin_id, plugin_info in data.items():
|
||||||
|
online_plugins.append({
|
||||||
|
"name": str(plugin_id),
|
||||||
|
"desc": str(plugin_info.get("desc", "")),
|
||||||
|
"version": str(plugin_info.get("version", "")),
|
||||||
|
"author": str(plugin_info.get("author", "")),
|
||||||
|
"repo": str(plugin_info.get("repo", "")),
|
||||||
|
"status": PluginStatus.NOT_INSTALLED,
|
||||||
|
"local_path": None,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||||
|
|
||||||
|
# 与在线插件比对,更新状态
|
||||||
|
online_plugin_names = {plugin["name"] for plugin in online_plugins}
|
||||||
|
for local_plugin in result:
|
||||||
|
if local_plugin["name"] in online_plugin_names:
|
||||||
|
# 查找对应的在线插件
|
||||||
|
online_plugin = next(
|
||||||
|
p for p in online_plugins if p["name"] == local_plugin["name"]
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
VersionComparator.compare_version(
|
||||||
|
local_plugin["version"], online_plugin["version"]
|
||||||
|
)
|
||||||
|
< 0
|
||||||
|
):
|
||||||
|
local_plugin["status"] = PluginStatus.NEED_UPDATE
|
||||||
|
else:
|
||||||
|
# 本地插件未在线上发布
|
||||||
|
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
|
||||||
|
|
||||||
|
# 添加未安装的在线插件
|
||||||
|
for online_plugin in online_plugins:
|
||||||
|
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
|
||||||
|
result.append(online_plugin)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def manage_plugin(
|
||||||
|
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""安装或更新插件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin (dict): 插件信息字典
|
||||||
|
plugins_dir (Path): 插件目录
|
||||||
|
is_update (bool, optional): 是否为更新操作. 默认为 False
|
||||||
|
proxy (str, optional): 代理服务器地址
|
||||||
|
"""
|
||||||
|
plugin_name = plugin["name"]
|
||||||
|
repo_url = plugin["repo"]
|
||||||
|
|
||||||
|
# 如果是更新且有本地路径,直接使用本地路径
|
||||||
|
if is_update and plugin.get("local_path"):
|
||||||
|
target_path = Path(plugin["local_path"])
|
||||||
|
else:
|
||||||
|
target_path = plugins_dir / plugin_name
|
||||||
|
|
||||||
|
backup_path = Path(f"{target_path}_backup") if is_update else None
|
||||||
|
|
||||||
|
# 检查插件是否存在
|
||||||
|
if is_update and not target_path.exists():
|
||||||
|
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
||||||
|
|
||||||
|
# 备份现有插件
|
||||||
|
if is_update and backup_path.exists():
|
||||||
|
shutil.rmtree(backup_path)
|
||||||
|
if is_update:
|
||||||
|
shutil.copytree(target_path, backup_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
click.echo(
|
||||||
|
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
|
||||||
|
)
|
||||||
|
get_git_repo(repo_url, target_path, proxy)
|
||||||
|
|
||||||
|
# 更新成功,删除备份
|
||||||
|
if is_update and backup_path.exists():
|
||||||
|
shutil.rmtree(backup_path)
|
||||||
|
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
||||||
|
except Exception as e:
|
||||||
|
if target_path.exists():
|
||||||
|
shutil.rmtree(target_path, ignore_errors=True)
|
||||||
|
if is_update and backup_path.exists():
|
||||||
|
shutil.move(backup_path, target_path)
|
||||||
|
raise click.ClickException(
|
||||||
|
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
||||||
|
)
|
||||||
92
astrbot/cli/utils/version_comparator.py
Normal file
92
astrbot/cli/utils/version_comparator.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
拷贝自 astrbot.core.utils.version_comparator
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class VersionComparator:
|
||||||
|
@staticmethod
|
||||||
|
def compare_version(v1: str, v2: str) -> int:
|
||||||
|
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||||
|
|
||||||
|
参考: https://semver.org/lang/zh-CN/
|
||||||
|
|
||||||
|
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||||
|
"""
|
||||||
|
v1 = v1.lower().replace("v", "")
|
||||||
|
v2 = v2.lower().replace("v", "")
|
||||||
|
|
||||||
|
def split_version(version):
|
||||||
|
match = re.match(
|
||||||
|
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
|
||||||
|
version,
|
||||||
|
)
|
||||||
|
if not match:
|
||||||
|
return [], None
|
||||||
|
major_minor_patch = match.group(1).split(".")
|
||||||
|
prerelease = match.group(2)
|
||||||
|
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||||
|
parts = [int(x) for x in major_minor_patch]
|
||||||
|
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||||
|
return parts, prerelease
|
||||||
|
|
||||||
|
v1_parts, v1_prerelease = split_version(v1)
|
||||||
|
v2_parts, v2_prerelease = split_version(v2)
|
||||||
|
|
||||||
|
# 比较数字部分
|
||||||
|
length = max(len(v1_parts), len(v2_parts))
|
||||||
|
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||||
|
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||||
|
|
||||||
|
for i in range(length):
|
||||||
|
if v1_parts[i] > v2_parts[i]:
|
||||||
|
return 1
|
||||||
|
elif v1_parts[i] < v2_parts[i]:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
# 比较预发布标签
|
||||||
|
if v1_prerelease is None and v2_prerelease is not None:
|
||||||
|
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||||
|
elif v1_prerelease is not None and v2_prerelease is None:
|
||||||
|
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||||
|
elif v1_prerelease is not None and v2_prerelease is not None:
|
||||||
|
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||||
|
for i in range(len_pre):
|
||||||
|
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
||||||
|
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
|
||||||
|
|
||||||
|
if p1 is None and p2 is not None:
|
||||||
|
return -1
|
||||||
|
elif p1 is not None and p2 is None:
|
||||||
|
return 1
|
||||||
|
elif isinstance(p1, int) and isinstance(p2, str):
|
||||||
|
return -1
|
||||||
|
elif isinstance(p1, str) and isinstance(p2, int):
|
||||||
|
return 1
|
||||||
|
elif isinstance(p1, int) and isinstance(p2, int):
|
||||||
|
if p1 > p2:
|
||||||
|
return 1
|
||||||
|
elif p1 < p2:
|
||||||
|
return -1
|
||||||
|
elif isinstance(p1, str) and isinstance(p2, str):
|
||||||
|
if p1 > p2:
|
||||||
|
return 1
|
||||||
|
elif p1 < p2:
|
||||||
|
return -1
|
||||||
|
return 0 # 预发布标签完全相同
|
||||||
|
|
||||||
|
return 0 # 数字部分和预发布标签都相同
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_prerelease(prerelease):
|
||||||
|
if not prerelease:
|
||||||
|
return None
|
||||||
|
parts = prerelease.split(".")
|
||||||
|
result = []
|
||||||
|
for part in parts:
|
||||||
|
if part.isdigit():
|
||||||
|
result.append(int(part))
|
||||||
|
else:
|
||||||
|
result.append(part)
|
||||||
|
return result
|
||||||
@@ -7,20 +7,28 @@ from astrbot.core.utils.pip_installer import PipInstaller
|
|||||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||||
from astrbot.core.config.default import DB_PATH
|
from astrbot.core.config.default import DB_PATH
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
|
from astrbot.core.file_token_service import FileTokenService
|
||||||
|
from .utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
os.makedirs("data", exist_ok=True)
|
# 初始化数据存储文件夹
|
||||||
|
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||||
|
|
||||||
|
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||||
|
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||||
|
|
||||||
astrbot_config = AstrBotConfig()
|
astrbot_config = AstrBotConfig()
|
||||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||||
html_renderer = HtmlRenderer(t2i_base_url)
|
html_renderer = HtmlRenderer(t2i_base_url)
|
||||||
logger = LogManager.GetLogger(log_name="astrbot")
|
logger = LogManager.GetLogger(log_name="astrbot")
|
||||||
|
|
||||||
if os.environ.get("TESTING", ""):
|
|
||||||
logger.setLevel("DEBUG")
|
|
||||||
|
|
||||||
db_helper = SQLiteDatabase(DB_PATH)
|
db_helper = SQLiteDatabase(DB_PATH)
|
||||||
sp = SharedPreferences() # 简单的偏好设置存储
|
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||||
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
sp = SharedPreferences()
|
||||||
|
# 文件令牌服务
|
||||||
|
file_token_service = FileTokenService()
|
||||||
|
pip_installer = PipInstaller(
|
||||||
|
astrbot_config.get("pip_install_arg", ""),
|
||||||
|
astrbot_config.get("pypi_index_url", None),
|
||||||
|
)
|
||||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ import logging
|
|||||||
import enum
|
import enum
|
||||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||||
logger = logging.getLogger("astrbot")
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
|
|
||||||
@@ -45,8 +46,6 @@ class AstrBotConfig(dict):
|
|||||||
|
|
||||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||||
conf_str = f.read()
|
conf_str = f.read()
|
||||||
if conf_str.startswith("/ufeff"): # remove BOM
|
|
||||||
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
|
|
||||||
conf = json.loads(conf_str)
|
conf = json.loads(conf_str)
|
||||||
|
|
||||||
# 检查配置完整性,并插入
|
# 检查配置完整性,并插入
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,10 @@
|
|||||||
|
"""
|
||||||
|
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
|
||||||
|
|
||||||
|
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
|
||||||
|
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||||
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -11,24 +18,34 @@ class ConversationManager:
|
|||||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||||
|
|
||||||
def __init__(self, db_helper: BaseDatabase):
|
def __init__(self, db_helper: BaseDatabase):
|
||||||
|
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||||
self.db = db_helper
|
self.db = db_helper
|
||||||
self.save_interval = 60 # 每 60 秒保存一次
|
self.save_interval = 60 # 每 60 秒保存一次
|
||||||
self._start_periodic_save()
|
self._start_periodic_save()
|
||||||
|
|
||||||
def _start_periodic_save(self):
|
def _start_periodic_save(self):
|
||||||
|
"""启动定时保存任务"""
|
||||||
asyncio.create_task(self._periodic_save())
|
asyncio.create_task(self._periodic_save())
|
||||||
|
|
||||||
async def _periodic_save(self):
|
async def _periodic_save(self):
|
||||||
|
"""定时保存会话对话映射关系到存储中"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(self.save_interval)
|
await asyncio.sleep(self.save_interval)
|
||||||
self._save_to_storage()
|
self._save_to_storage()
|
||||||
|
|
||||||
def _save_to_storage(self):
|
def _save_to_storage(self):
|
||||||
|
"""保存会话对话映射关系到存储中"""
|
||||||
sp.put("session_conversation", self.session_conversations)
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||||
"""新建对话,并将当前会话的对话转移到新对话"""
|
"""新建对话,并将当前会话的对话转移到新对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
conversation_id = str(uuid.uuid4())
|
conversation_id = str(uuid.uuid4())
|
||||||
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||||
self.session_conversations[unified_msg_origin] = conversation_id
|
self.session_conversations[unified_msg_origin] = conversation_id
|
||||||
@@ -36,14 +53,24 @@ class ConversationManager:
|
|||||||
return conversation_id
|
return conversation_id
|
||||||
|
|
||||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||||
"""切换会话的对话"""
|
"""切换会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
self.session_conversations[unified_msg_origin] = conversation_id
|
self.session_conversations[unified_msg_origin] = conversation_id
|
||||||
sp.put("session_conversation", self.session_conversations)
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
async def delete_conversation(
|
async def delete_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str = None
|
self, unified_msg_origin: str, conversation_id: str = None
|
||||||
):
|
):
|
||||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话"""
|
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||||
@@ -51,23 +78,48 @@ class ConversationManager:
|
|||||||
sp.put("session_conversation", self.session_conversations)
|
sp.put("session_conversation", self.session_conversations)
|
||||||
|
|
||||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||||
"""获取会话当前的对话 ID"""
|
"""获取会话当前的对话 ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
"""
|
||||||
return self.session_conversations.get(unified_msg_origin, None)
|
return self.session_conversations.get(unified_msg_origin, None)
|
||||||
|
|
||||||
async def get_conversation(
|
async def get_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str
|
self, unified_msg_origin: str, conversation_id: str
|
||||||
) -> Conversation:
|
) -> Conversation:
|
||||||
"""获取会话的对话"""
|
"""获取会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
Returns:
|
||||||
|
conversation (Conversation): 对话对象
|
||||||
|
"""
|
||||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||||
|
|
||||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||||
"""获取会话的所有对话"""
|
"""获取会话的所有对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
Returns:
|
||||||
|
conversations (List[Conversation]): 对话对象列表
|
||||||
|
"""
|
||||||
return self.db.get_conversations(unified_msg_origin)
|
return self.db.get_conversations(unified_msg_origin)
|
||||||
|
|
||||||
async def update_conversation(
|
async def update_conversation(
|
||||||
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||||
):
|
):
|
||||||
"""更新会话的对话"""
|
"""更新会话的对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||||
|
"""
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.update_conversation(
|
self.db.update_conversation(
|
||||||
user_id=unified_msg_origin,
|
user_id=unified_msg_origin,
|
||||||
@@ -76,7 +128,12 @@ class ConversationManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||||
"""更新会话的对话标题"""
|
"""更新会话的对话标题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
title (str): 对话标题
|
||||||
|
"""
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.update_conversation_title(
|
self.db.update_conversation_title(
|
||||||
@@ -86,7 +143,12 @@ class ConversationManager:
|
|||||||
async def update_conversation_persona_id(
|
async def update_conversation_persona_id(
|
||||||
self, unified_msg_origin: str, persona_id: str
|
self, unified_msg_origin: str, persona_id: str
|
||||||
):
|
):
|
||||||
"""更新会话的对话 Persona ID"""
|
"""更新会话的对话 Persona ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
persona_id (str): 对话 Persona ID
|
||||||
|
"""
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
self.db.update_conversation_persona_id(
|
self.db.update_conversation_persona_id(
|
||||||
@@ -96,6 +158,14 @@ class ConversationManager:
|
|||||||
async def get_human_readable_context(
|
async def get_human_readable_context(
|
||||||
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||||
):
|
):
|
||||||
|
"""获取人类可读的上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
|
page (int): 页码
|
||||||
|
page_size (int): 每页大小
|
||||||
|
"""
|
||||||
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||||
history = json.loads(conversation.history)
|
history = json.loads(conversation.history)
|
||||||
|
|
||||||
@@ -105,7 +175,15 @@ class ConversationManager:
|
|||||||
if record["role"] == "user":
|
if record["role"] == "user":
|
||||||
temp_contexts.append(f"User: {record['content']}")
|
temp_contexts.append(f"User: {record['content']}")
|
||||||
elif record["role"] == "assistant":
|
elif record["role"] == "assistant":
|
||||||
temp_contexts.append(f"Assistant: {record['content']}")
|
if "content" in record and record["content"]:
|
||||||
|
temp_contexts.append(f"Assistant: {record['content']}")
|
||||||
|
elif "tool_calls" in record:
|
||||||
|
tool_calls_str = json.dumps(
|
||||||
|
record["tool_calls"], ensure_ascii=False
|
||||||
|
)
|
||||||
|
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
|
||||||
|
else:
|
||||||
|
temp_contexts.append("Assistant: [未知的内容]")
|
||||||
contexts.insert(0, temp_contexts)
|
contexts.insert(0, temp_contexts)
|
||||||
temp_contexts = []
|
temp_contexts = []
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||||
|
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||||
|
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 初始化所有组件
|
||||||
|
2. 启动事件总线和任务, 所有任务都在这里运行
|
||||||
|
3. 执行启动完成事件钩子
|
||||||
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -17,39 +28,54 @@ from astrbot.core.db import BaseDatabase
|
|||||||
from astrbot.core.updator import AstrBotUpdator
|
from astrbot.core.updator import AstrBotUpdator
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.config.default import VERSION
|
from astrbot.core.config.default import VERSION
|
||||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
|
||||||
from astrbot.core.conversation_mgr import ConversationManager
|
from astrbot.core.conversation_mgr import ConversationManager
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star_handler import star_map
|
from astrbot.core.star.star_handler import star_map
|
||||||
|
|
||||||
|
|
||||||
class AstrBotCoreLifecycle:
|
class AstrBotCoreLifecycle:
|
||||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
"""
|
||||||
self.log_broker = log_broker
|
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||||
self.astrbot_config = astrbot_config
|
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||||
self.db = db
|
EventBus 等。
|
||||||
|
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||||
|
self.log_broker = log_broker # 初始化日志代理
|
||||||
|
self.astrbot_config = astrbot_config # 初始化配置
|
||||||
|
self.db = db # 初始化数据库
|
||||||
|
|
||||||
|
# 根据环境变量设置代理
|
||||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||||
os.environ["no_proxy"] = "localhost"
|
os.environ["no_proxy"] = "localhost"
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
"""
|
||||||
|
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 初始化日志代理
|
||||||
logger.info("AstrBot v" + VERSION)
|
logger.info("AstrBot v" + VERSION)
|
||||||
if os.environ.get("TESTING", ""):
|
if os.environ.get("TESTING", ""):
|
||||||
logger.setLevel("DEBUG")
|
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
|
||||||
else:
|
else:
|
||||||
logger.setLevel(self.astrbot_config["log_level"])
|
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||||
self.event_queue = Queue()
|
|
||||||
self.event_queue.closed = False
|
|
||||||
|
|
||||||
|
# 初始化事件队列
|
||||||
|
self.event_queue = Queue()
|
||||||
|
|
||||||
|
# 初始化供应商管理器
|
||||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||||
|
|
||||||
|
# 初始化平台管理器
|
||||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||||
|
|
||||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
# 初始化对话管理器
|
||||||
|
|
||||||
self.conversation_manager = ConversationManager(self.db)
|
self.conversation_manager = ConversationManager(self.db)
|
||||||
|
|
||||||
|
# 初始化提供给插件的上下文
|
||||||
self.star_context = Context(
|
self.star_context = Context(
|
||||||
self.event_queue,
|
self.event_queue,
|
||||||
self.astrbot_config,
|
self.astrbot_config,
|
||||||
@@ -57,35 +83,51 @@ class AstrBotCoreLifecycle:
|
|||||||
self.provider_manager,
|
self.provider_manager,
|
||||||
self.platform_manager,
|
self.platform_manager,
|
||||||
self.conversation_manager,
|
self.conversation_manager,
|
||||||
self.knowledge_db_manager,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 初始化插件管理器
|
||||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||||
|
|
||||||
|
# 扫描、注册插件、实例化插件类
|
||||||
await self.plugin_manager.reload()
|
await self.plugin_manager.reload()
|
||||||
"""扫描、注册插件、实例化插件类"""
|
|
||||||
|
|
||||||
|
# 根据配置实例化各个 Provider
|
||||||
await self.provider_manager.initialize()
|
await self.provider_manager.initialize()
|
||||||
"""根据配置实例化各个 Provider"""
|
|
||||||
|
|
||||||
|
# 初始化消息事件流水线调度器
|
||||||
self.pipeline_scheduler = PipelineScheduler(
|
self.pipeline_scheduler = PipelineScheduler(
|
||||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||||
)
|
)
|
||||||
await self.pipeline_scheduler.initialize()
|
await self.pipeline_scheduler.initialize()
|
||||||
"""初始化消息事件流水线调度器"""
|
|
||||||
|
|
||||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
# 初始化更新器
|
||||||
|
self.astrbot_updator = AstrBotUpdator()
|
||||||
|
|
||||||
|
# 初始化事件总线
|
||||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||||
|
|
||||||
|
# 记录启动时间
|
||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
|
# 初始化当前任务列表
|
||||||
self.curr_tasks: List[asyncio.Task] = []
|
self.curr_tasks: List[asyncio.Task] = []
|
||||||
|
|
||||||
|
# 根据配置实例化各个平台适配器
|
||||||
await self.platform_manager.initialize()
|
await self.platform_manager.initialize()
|
||||||
"""根据配置实例化各个平台适配器"""
|
|
||||||
|
# 初始化关闭控制面板的事件
|
||||||
|
self.dashboard_shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
def _load(self):
|
def _load(self):
|
||||||
|
"""加载事件总线和任务并初始化"""
|
||||||
|
|
||||||
|
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||||
|
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
|
||||||
event_bus_task = asyncio.create_task(
|
event_bus_task = asyncio.create_task(
|
||||||
self.event_bus.dispatch(), name="event_bus"
|
self.event_bus.dispatch(), name="event_bus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||||
extra_tasks = []
|
extra_tasks = []
|
||||||
for task in self.star_context._register_tasks:
|
for task in self.star_context._register_tasks:
|
||||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||||
@@ -99,17 +141,24 @@ class AstrBotCoreLifecycle:
|
|||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
async def _task_wrapper(self, task: asyncio.Task):
|
async def _task_wrapper(self, task: asyncio.Task):
|
||||||
|
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (asyncio.Task): 要执行的异步任务
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass # 任务被取消, 静默处理
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
|
||||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||||
for line in traceback.format_exc().split("\n"):
|
for line in traceback.format_exc().split("\n"):
|
||||||
logger.error(f"| {line}")
|
logger.error(f"| {line}")
|
||||||
logger.error("-------")
|
logger.error("-------")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
|
||||||
self._load()
|
self._load()
|
||||||
logger.info("AstrBot 启动完成。")
|
logger.info("AstrBot 启动完成。")
|
||||||
|
|
||||||
@@ -126,15 +175,29 @@ class AstrBotCoreLifecycle:
|
|||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# 同时运行curr_tasks中的所有任务
|
||||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.event_queue.closed = True
|
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
|
||||||
|
# 请求停止所有正在运行的异步任务
|
||||||
for task in self.curr_tasks:
|
for task in self.curr_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
await self.provider_manager.terminate()
|
for plugin in self.plugin_manager.context.get_all_stars():
|
||||||
|
try:
|
||||||
|
await self.plugin_manager._terminate_plugin(plugin)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(traceback.format_exc())
|
||||||
|
logger.warning(
|
||||||
|
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.provider_manager.terminate()
|
||||||
|
await self.platform_manager.terminate()
|
||||||
|
self.dashboard_shutdown_event.set()
|
||||||
|
|
||||||
|
# 再次遍历curr_tasks等待每个任务真正结束
|
||||||
for task in self.curr_tasks:
|
for task in self.curr_tasks:
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
@@ -143,13 +206,17 @@ class AstrBotCoreLifecycle:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||||
|
|
||||||
def restart(self):
|
async def restart(self):
|
||||||
self.event_queue.closed = True
|
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||||
|
await self.provider_manager.terminate()
|
||||||
|
await self.platform_manager.terminate()
|
||||||
|
self.dashboard_shutdown_event.set()
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
def load_platform(self) -> List[asyncio.Task]:
|
def load_platform(self) -> List[asyncio.Task]:
|
||||||
|
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||||
tasks = []
|
tasks = []
|
||||||
platform_insts = self.platform_manager.get_insts()
|
platform_insts = self.platform_manager.get_insts()
|
||||||
for platform_inst in platform_insts:
|
for platform_inst in platform_insts:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List, Dict, Any, Tuple
|
||||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||||
|
|
||||||
|
|
||||||
@@ -117,3 +117,45 @@ class BaseDatabase(abc.ABC):
|
|||||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||||
"""更新 Conversation Persona ID"""
|
"""更新 Conversation Persona ID"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_all_conversations(
|
||||||
|
self, page: int = 1, page_size: int = 20
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取所有对话,支持分页
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码,从1开始
|
||||||
|
page_size: 每页数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
platforms: List[str] = None,
|
||||||
|
message_types: List[str] = None,
|
||||||
|
search_query: str = None,
|
||||||
|
exclude_ids: List[str] = None,
|
||||||
|
exclude_platforms: List[str] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取筛选后的对话列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码
|
||||||
|
page_size: 每页数量
|
||||||
|
platforms: 平台筛选列表
|
||||||
|
message_types: 消息类型筛选列表
|
||||||
|
search_query: 搜索关键词
|
||||||
|
exclude_ids: 排除的用户ID列表
|
||||||
|
exclude_platforms: 排除的平台列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import List
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Platform:
|
class Platform:
|
||||||
|
"""平台使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
@@ -13,6 +15,8 @@ class Platform:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Provider:
|
class Provider:
|
||||||
|
"""供应商使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
@@ -20,6 +24,8 @@ class Provider:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Plugin:
|
class Plugin:
|
||||||
|
"""插件使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
@@ -27,6 +33,8 @@ class Plugin:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Command:
|
class Command:
|
||||||
|
"""命令使用统计数据"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
timestamp: int
|
timestamp: int
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
|
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
|
||||||
from . import BaseDatabase
|
from . import BaseDatabase
|
||||||
from typing import Tuple
|
from typing import Tuple, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase(BaseDatabase):
|
class SQLiteDatabase(BaseDatabase):
|
||||||
@@ -128,24 +128,23 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
c = self._get_conn(self.db_path).cursor()
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
where_clause = ""
|
conditions = []
|
||||||
if session_id or provider_type:
|
params = []
|
||||||
where_clause += " WHERE "
|
|
||||||
has = False
|
if session_id:
|
||||||
if session_id:
|
conditions.append("session_id = ?")
|
||||||
where_clause += f"session_id = '{session_id}'"
|
params.append(session_id)
|
||||||
has = True
|
|
||||||
if provider_type:
|
if provider_type:
|
||||||
if has:
|
conditions.append("provider_type = ?")
|
||||||
where_clause += " AND "
|
params.append(provider_type)
|
||||||
where_clause += f"provider_type = '{provider_type}'"
|
|
||||||
|
sql = "SELECT * FROM llm_history"
|
||||||
|
if conditions:
|
||||||
|
sql += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
|
c.execute(sql, params)
|
||||||
|
|
||||||
c.execute(
|
|
||||||
"""
|
|
||||||
SELECT * FROM llm_history
|
|
||||||
"""
|
|
||||||
+ where_clause
|
|
||||||
)
|
|
||||||
res = c.fetchall()
|
res = c.fetchall()
|
||||||
histories = []
|
histories = []
|
||||||
for row in res:
|
for row in res:
|
||||||
@@ -389,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
if res:
|
if res:
|
||||||
return ATRIVision(*res)
|
return ATRIVision(*res)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_all_conversations(
|
||||||
|
self, page: int = 1, page_size: int = 20
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取所有对话,支持分页,按更新时间降序排序"""
|
||||||
|
try:
|
||||||
|
c = self.conn.cursor()
|
||||||
|
except sqlite3.ProgrammingError:
|
||||||
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取总记录数
|
||||||
|
c.execute("""
|
||||||
|
SELECT COUNT(*) FROM webchat_conversation
|
||||||
|
""")
|
||||||
|
total_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
# 计算偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 获取分页数据,按更新时间降序排序
|
||||||
|
c.execute(
|
||||||
|
"""
|
||||||
|
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||||
|
FROM webchat_conversation
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
""",
|
||||||
|
(page_size, offset),
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = c.fetchall()
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||||
|
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
|
||||||
|
safe_cid = str(cid) if cid else "unknown"
|
||||||
|
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||||
|
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"cid": safe_cid,
|
||||||
|
"title": title or f"对话 {display_cid}",
|
||||||
|
"persona_id": persona_id or "",
|
||||||
|
"created_at": created_at or 0,
|
||||||
|
"updated_at": updated_at or 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total_count
|
||||||
|
|
||||||
|
except Exception as _:
|
||||||
|
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||||
|
return [], 0
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
def get_filtered_conversations(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
platforms: List[str] = None,
|
||||||
|
message_types: List[str] = None,
|
||||||
|
search_query: str = None,
|
||||||
|
exclude_ids: List[str] = None,
|
||||||
|
exclude_platforms: List[str] = None,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""获取筛选后的对话列表"""
|
||||||
|
try:
|
||||||
|
c = self.conn.cursor()
|
||||||
|
except sqlite3.ProgrammingError:
|
||||||
|
c = self._get_conn(self.db_path).cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建查询条件
|
||||||
|
where_clauses = []
|
||||||
|
params = []
|
||||||
|
|
||||||
|
# 平台筛选
|
||||||
|
if platforms and len(platforms) > 0:
|
||||||
|
platform_conditions = []
|
||||||
|
for platform in platforms:
|
||||||
|
platform_conditions.append("user_id LIKE ?")
|
||||||
|
params.append(f"{platform}:%")
|
||||||
|
|
||||||
|
if platform_conditions:
|
||||||
|
where_clauses.append(f"({' OR '.join(platform_conditions)})")
|
||||||
|
|
||||||
|
# 消息类型筛选
|
||||||
|
if message_types and len(message_types) > 0:
|
||||||
|
message_type_conditions = []
|
||||||
|
for msg_type in message_types:
|
||||||
|
message_type_conditions.append("user_id LIKE ?")
|
||||||
|
params.append(f"%:{msg_type}:%")
|
||||||
|
|
||||||
|
if message_type_conditions:
|
||||||
|
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
||||||
|
|
||||||
|
# 搜索关键词
|
||||||
|
if search_query:
|
||||||
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||||
|
where_clauses.append(
|
||||||
|
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
||||||
|
)
|
||||||
|
search_param = f"%{search_query}%"
|
||||||
|
params.extend([search_param, search_param, search_param, search_param])
|
||||||
|
|
||||||
|
# 排除特定用户ID
|
||||||
|
if exclude_ids and len(exclude_ids) > 0:
|
||||||
|
for exclude_id in exclude_ids:
|
||||||
|
where_clauses.append("user_id NOT LIKE ?")
|
||||||
|
params.append(f"{exclude_id}%")
|
||||||
|
|
||||||
|
# 排除特定平台
|
||||||
|
if exclude_platforms and len(exclude_platforms) > 0:
|
||||||
|
for exclude_platform in exclude_platforms:
|
||||||
|
where_clauses.append("user_id NOT LIKE ?")
|
||||||
|
params.append(f"{exclude_platform}:%")
|
||||||
|
|
||||||
|
# 构建完整的 WHERE 子句
|
||||||
|
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||||
|
|
||||||
|
# 构建计数查询
|
||||||
|
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
|
||||||
|
|
||||||
|
# 获取总记录数
|
||||||
|
c.execute(count_sql, params)
|
||||||
|
total_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
# 计算偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 构建分页数据查询
|
||||||
|
data_sql = f"""
|
||||||
|
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||||
|
FROM webchat_conversation
|
||||||
|
{where_sql}
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
query_params = params + [page_size, offset]
|
||||||
|
|
||||||
|
# 获取分页数据
|
||||||
|
c.execute(data_sql, query_params)
|
||||||
|
rows = c.fetchall()
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||||
|
# 确保 cid 是字符串类型,否则使用一个默认值
|
||||||
|
safe_cid = str(cid) if cid else "unknown"
|
||||||
|
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||||
|
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"cid": safe_cid,
|
||||||
|
"title": title or f"对话 {display_cid}",
|
||||||
|
"persona_id": persona_id or "",
|
||||||
|
"created_at": created_at or 0,
|
||||||
|
"updated_at": updated_at or 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total_count
|
||||||
|
|
||||||
|
except Exception as _:
|
||||||
|
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||||
|
return [], 0
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
|||||||
@@ -38,11 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||||
user_id TEXT,
|
user_id TEXT, -- 会话 id
|
||||||
cid TEXT,
|
cid TEXT, -- 对话 id
|
||||||
history TEXT,
|
history TEXT,
|
||||||
created_at INTEGER,
|
created_at INTEGER,
|
||||||
updated_at INTEGER,
|
updated_at INTEGER,
|
||||||
title TEXT,
|
title TEXT,
|
||||||
persona_id TEXT
|
persona_id TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
|
PRAGMA encoding = 'UTF-8';
|
||||||
46
astrbot/core/db/vec_db/base.py
Normal file
46
astrbot/core/db/vec_db/base.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import abc
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Result:
|
||||||
|
similarity: float
|
||||||
|
data: dict
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVecDB:
|
||||||
|
async def initialize(self):
|
||||||
|
"""
|
||||||
|
初始化向量数据库
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||||
|
"""
|
||||||
|
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||||
|
"""
|
||||||
|
搜索最相似的文档。
|
||||||
|
Args:
|
||||||
|
query (str): 查询文本
|
||||||
|
top_k (int): 返回的最相似文档的数量
|
||||||
|
Returns:
|
||||||
|
List[Result]: 查询结果
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete(self, doc_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除指定文档。
|
||||||
|
Args:
|
||||||
|
doc_id (str): 要删除的文档 ID
|
||||||
|
Returns:
|
||||||
|
bool: 删除是否成功
|
||||||
|
"""
|
||||||
|
...
|
||||||
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .vec_db import FaissVecDB
|
||||||
|
|
||||||
|
__all__ = ["FaissVecDB"]
|
||||||
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import aiosqlite
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentStorage:
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.connection = None
|
||||||
|
self.sqlite_init_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "sqlite_init.sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||||
|
if not os.path.exists(self.db_path):
|
||||||
|
await self.connect()
|
||||||
|
async with self.connection.cursor() as cursor:
|
||||||
|
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||||
|
sql_script = f.read()
|
||||||
|
await cursor.executescript(sql_script)
|
||||||
|
await self.connection.commit()
|
||||||
|
else:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Connect to the SQLite database."""
|
||||||
|
self.connection = await aiosqlite.connect(self.db_path)
|
||||||
|
|
||||||
|
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||||
|
"""Retrieve documents by metadata filters and ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata_filters (dict): The metadata filters to apply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||||
|
"""
|
||||||
|
# metadata filter -> SQL WHERE clause
|
||||||
|
where_clauses = []
|
||||||
|
values = []
|
||||||
|
for key, val in metadata_filters.items():
|
||||||
|
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||||
|
values.append(val)
|
||||||
|
if ids is not None and len(ids) > 0:
|
||||||
|
ids = [str(i) for i in ids if i != -1]
|
||||||
|
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||||
|
values.extend(ids)
|
||||||
|
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||||
|
|
||||||
|
result = []
|
||||||
|
async with self.connection.cursor() as cursor:
|
||||||
|
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||||
|
await cursor.execute(sql, values)
|
||||||
|
for row in await cursor.fetchall():
|
||||||
|
result.append(await self.tuple_to_dict(row))
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_document_by_doc_id(self, doc_id: str):
|
||||||
|
"""Retrieve a document by its doc_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id (str): The doc_id of the document to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The document data.
|
||||||
|
"""
|
||||||
|
async with self.connection.cursor() as cursor:
|
||||||
|
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return await self.tuple_to_dict(row)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||||
|
"""Retrieve a document by its doc_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id (str): The doc_id.
|
||||||
|
new_text (str): The new text to update the document with.
|
||||||
|
"""
|
||||||
|
async with self.connection.cursor() as cursor:
|
||||||
|
await cursor.execute(
|
||||||
|
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||||
|
)
|
||||||
|
await self.connection.commit()
|
||||||
|
|
||||||
|
async def get_user_ids(self) -> list[str]:
|
||||||
|
"""Retrieve all user IDs from the documents table.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of user IDs.
|
||||||
|
"""
|
||||||
|
async with self.connection.cursor() as cursor:
|
||||||
|
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
|
async def tuple_to_dict(self, row):
|
||||||
|
"""Convert a tuple to a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
row (tuple): The row to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The converted dictionary.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": row[0],
|
||||||
|
"doc_id": row[1],
|
||||||
|
"text": row[2],
|
||||||
|
"metadata": row[3],
|
||||||
|
"created_at": row[4],
|
||||||
|
"updated_at": row[5],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the connection to the SQLite database."""
|
||||||
|
if self.connection:
|
||||||
|
await self.connection.close()
|
||||||
|
self.connection = None
|
||||||
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
try:
|
||||||
|
import faiss
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
raise ImportError(
|
||||||
|
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
||||||
|
)
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingStorage:
|
||||||
|
def __init__(self, dimension: int, path: str = None):
|
||||||
|
self.dimension = dimension
|
||||||
|
self.path = path
|
||||||
|
self.index = None
|
||||||
|
if path and os.path.exists(path):
|
||||||
|
self.index = faiss.read_index(path)
|
||||||
|
else:
|
||||||
|
base_index = faiss.IndexFlatL2(dimension)
|
||||||
|
self.index = faiss.IndexIDMap(base_index)
|
||||||
|
self.storage = {}
|
||||||
|
|
||||||
|
async def insert(self, vector: np.ndarray, id: int):
|
||||||
|
"""插入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector (np.ndarray): 要插入的向量
|
||||||
|
id (int): 向量的ID
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果向量的维度与存储的维度不匹配
|
||||||
|
"""
|
||||||
|
if vector.shape[0] != self.dimension:
|
||||||
|
raise ValueError(
|
||||||
|
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||||
|
)
|
||||||
|
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||||
|
self.storage[id] = vector
|
||||||
|
await self.save_index()
|
||||||
|
|
||||||
|
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||||
|
"""搜索最相似的向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector (np.ndarray): 查询向量
|
||||||
|
k (int): 返回的最相似向量的数量
|
||||||
|
Returns:
|
||||||
|
tuple: (距离, 索引)
|
||||||
|
"""
|
||||||
|
faiss.normalize_L2(vector)
|
||||||
|
distances, indices = self.index.search(vector, k)
|
||||||
|
return distances, indices
|
||||||
|
|
||||||
|
async def save_index(self):
|
||||||
|
"""保存索引
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): 保存索引的路径
|
||||||
|
"""
|
||||||
|
faiss.write_index(self.index, self.path)
|
||||||
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at
|
||||||
|
CREATE TABLE documents (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
doc_id TEXT NOT NULL,
|
||||||
|
text TEXT NOT NULL,
|
||||||
|
metadata TEXT,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE documents
|
||||||
|
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
|
||||||
|
ALTER TABLE documents
|
||||||
|
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
|
||||||
|
|
||||||
|
CREATE INDEX idx_documents_user_id ON documents(user_id);
|
||||||
|
CREATE INDEX idx_documents_group_id ON documents(group_id);
|
||||||
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from .document_storage import DocumentStorage
|
||||||
|
from .embedding_storage import EmbeddingStorage
|
||||||
|
from ..base import Result, BaseVecDB
|
||||||
|
from astrbot.core.provider.provider import EmbeddingProvider
|
||||||
|
|
||||||
|
|
||||||
|
class FaissVecDB(BaseVecDB):
|
||||||
|
"""
|
||||||
|
A class to represent a vector database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
doc_store_path: str,
|
||||||
|
index_store_path: str,
|
||||||
|
embedding_provider: EmbeddingProvider,
|
||||||
|
):
|
||||||
|
self.doc_store_path = doc_store_path
|
||||||
|
self.index_store_path = index_store_path
|
||||||
|
self.embedding_provider = embedding_provider
|
||||||
|
self.document_storage = DocumentStorage(doc_store_path)
|
||||||
|
self.embedding_storage = EmbeddingStorage(
|
||||||
|
embedding_provider.get_dim(), index_store_path
|
||||||
|
)
|
||||||
|
self.embedding_provider = embedding_provider
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
await self.document_storage.initialize()
|
||||||
|
|
||||||
|
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||||
|
"""
|
||||||
|
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||||
|
"""
|
||||||
|
metadata = metadata or {}
|
||||||
|
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||||
|
|
||||||
|
vector = await self.embedding_provider.get_embedding(content)
|
||||||
|
vector = np.array(vector, dtype=np.float32)
|
||||||
|
async with self.document_storage.connection.cursor() as cursor:
|
||||||
|
await cursor.execute(
|
||||||
|
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||||
|
(str_id, content, json.dumps(metadata)),
|
||||||
|
)
|
||||||
|
await self.document_storage.connection.commit()
|
||||||
|
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||||
|
int_id = result["id"]
|
||||||
|
|
||||||
|
# 插入向量到 FAISS
|
||||||
|
await self.embedding_storage.insert(vector, int_id)
|
||||||
|
return int_id
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||||
|
) -> list[Result]:
|
||||||
|
"""
|
||||||
|
搜索最相似的文档。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): 查询文本
|
||||||
|
k (int): 返回的最相似文档的数量
|
||||||
|
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||||
|
metadata_filters (dict): 元数据过滤器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Result]: 查询结果
|
||||||
|
"""
|
||||||
|
embedding = await self.embedding_provider.get_embedding(query)
|
||||||
|
scores, indices = await self.embedding_storage.search(
|
||||||
|
vector=np.array([embedding]).astype("float32"),
|
||||||
|
k=fetch_k if metadata_filters else k,
|
||||||
|
)
|
||||||
|
# TODO: rerank
|
||||||
|
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||||
|
return []
|
||||||
|
# normalize scores
|
||||||
|
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||||
|
# NOTE: maybe the size is less than k.
|
||||||
|
fetched_docs = await self.document_storage.get_documents(
|
||||||
|
metadata_filters=metadata_filters or {}, ids=indices[0]
|
||||||
|
)
|
||||||
|
if not fetched_docs:
|
||||||
|
return []
|
||||||
|
result_docs = []
|
||||||
|
|
||||||
|
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||||
|
for i, indice_idx in enumerate(indices[0]):
|
||||||
|
pos = idx_pos.get(indice_idx)
|
||||||
|
if pos is None:
|
||||||
|
continue
|
||||||
|
fetch_doc = fetched_docs[pos]
|
||||||
|
score = scores[0][i]
|
||||||
|
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||||
|
return result_docs[:k]
|
||||||
|
|
||||||
|
async def delete(self, doc_id: int):
|
||||||
|
"""
|
||||||
|
删除一条文档
|
||||||
|
"""
|
||||||
|
await self.document_storage.connection.execute(
|
||||||
|
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||||
|
)
|
||||||
|
await self.document_storage.connection.commit()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
await self.document_storage.close()
|
||||||
|
|
||||||
|
async def count_documents(self) -> int:
|
||||||
|
"""
|
||||||
|
计算文档数量
|
||||||
|
"""
|
||||||
|
async with self.document_storage.connection.cursor() as cursor:
|
||||||
|
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||||
|
count = await cursor.fetchone()
|
||||||
|
return count[0] if count else 0
|
||||||
@@ -1,3 +1,16 @@
|
|||||||
|
"""
|
||||||
|
事件总线, 用于处理事件的分发和处理
|
||||||
|
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||||
|
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
|
class:
|
||||||
|
EventBus: 事件总线, 用于处理事件的分发和处理
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 维护一个异步队列, 来接受各种消息事件
|
||||||
|
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||||
@@ -6,21 +19,38 @@ from .platform import AstrMessageEvent
|
|||||||
|
|
||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
|
"""事件总线: 用于处理事件的分发和处理
|
||||||
|
|
||||||
|
维护一个异步队列, 来接受各种消息事件
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue # 事件队列
|
||||||
self.pipeline_scheduler = pipeline_scheduler
|
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||||
|
|
||||||
async def dispatch(self):
|
async def dispatch(self):
|
||||||
|
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||||
while True:
|
while True:
|
||||||
event: AstrMessageEvent = await self.event_queue.get()
|
event: AstrMessageEvent = (
|
||||||
self._print_event(event)
|
await self.event_queue.get()
|
||||||
asyncio.create_task(self.pipeline_scheduler.execute(event))
|
) # 从事件队列中获取新的事件
|
||||||
|
self._print_event(event) # 打印日志
|
||||||
|
asyncio.create_task(
|
||||||
|
self.pipeline_scheduler.execute(event)
|
||||||
|
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
def _print_event(self, event: AstrMessageEvent):
|
def _print_event(self, event: AstrMessageEvent):
|
||||||
|
"""用于记录事件信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
"""
|
||||||
|
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||||
if event.get_sender_name():
|
if event.get_sender_name():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||||
)
|
)
|
||||||
|
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||||
|
|||||||
68
astrbot/core/file_token_service.py
Normal file
68
astrbot/core/file_token_service.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class FileTokenService:
|
||||||
|
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||||
|
|
||||||
|
def __init__(self, default_timeout: float = 300):
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self.staged_files = {} # token: (file_path, expire_time)
|
||||||
|
self.default_timeout = default_timeout
|
||||||
|
|
||||||
|
async def _cleanup_expired_tokens(self):
|
||||||
|
"""清理过期的令牌"""
|
||||||
|
now = time.time()
|
||||||
|
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
||||||
|
for token in expired_tokens:
|
||||||
|
self.staged_files.pop(token, None)
|
||||||
|
|
||||||
|
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
||||||
|
"""向令牌服务注册一个文件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path(str): 文件路径
|
||||||
|
timeout(float): 超时时间,单位秒(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 一个单次令牌
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: 当路径不存在时抛出
|
||||||
|
"""
|
||||||
|
async with self.lock:
|
||||||
|
await self._cleanup_expired_tokens()
|
||||||
|
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||||
|
|
||||||
|
file_token = str(uuid.uuid4())
|
||||||
|
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
||||||
|
self.staged_files[file_token] = (file_path, expire_time)
|
||||||
|
return file_token
|
||||||
|
|
||||||
|
async def handle_file(self, file_token: str) -> str:
|
||||||
|
"""根据令牌获取文件路径,使用后令牌失效。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_token(str): 注册时返回的令牌
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 文件路径
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 当令牌不存在或已过期时抛出
|
||||||
|
FileNotFoundError: 当文件本身已被删除时抛出
|
||||||
|
"""
|
||||||
|
async with self.lock:
|
||||||
|
await self._cleanup_expired_tokens()
|
||||||
|
|
||||||
|
if file_token not in self.staged_files:
|
||||||
|
raise KeyError(f"无效或过期的文件 token: {file_token}")
|
||||||
|
|
||||||
|
file_path, _ = self.staged_files.pop(file_token)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||||
|
return file_path
|
||||||
@@ -1,35 +1,49 @@
|
|||||||
|
"""
|
||||||
|
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
|
||||||
|
2. 运行核心生命周期任务和仪表板服务器
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from .server import AstrBotDashboard
|
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core import LogBroker
|
from astrbot.core import LogBroker
|
||||||
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
|
|
||||||
|
|
||||||
class AstrBotDashBoardLifecycle:
|
class InitialLoader:
|
||||||
|
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
|
||||||
|
|
||||||
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
self.dashboard_server = None
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||||
|
|
||||||
core_task = []
|
|
||||||
try:
|
try:
|
||||||
await core_lifecycle.initialize()
|
await core_lifecycle.initialize()
|
||||||
core_task = core_lifecycle.start()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.critical(traceback.format_exc())
|
logger.critical(traceback.format_exc())
|
||||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||||
|
return
|
||||||
|
|
||||||
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
|
core_task = core_lifecycle.start()
|
||||||
task = asyncio.gather(core_task, self.dashboard_server.run())
|
|
||||||
|
self.dashboard_server = AstrBotDashboard(
|
||||||
|
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||||
|
)
|
||||||
|
task = asyncio.gather(
|
||||||
|
core_task, self.dashboard_server.run()
|
||||||
|
) # 启动核心任务和仪表板服务器
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await task
|
await task # 整个AstrBot在这里运行
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("🌈 正在关闭 AstrBot...")
|
logger.info("🌈 正在关闭 AstrBot...")
|
||||||
await core_lifecycle.stop()
|
await core_lifecycle.stop()
|
||||||
@@ -1,12 +1,38 @@
|
|||||||
|
"""
|
||||||
|
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
|
||||||
|
|
||||||
|
const:
|
||||||
|
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
|
||||||
|
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
|
||||||
|
|
||||||
|
class:
|
||||||
|
LogBroker: 日志代理类, 用于缓存和分发日志消息
|
||||||
|
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
LogManager: 日志管理器, 用于创建和配置日志记录器
|
||||||
|
|
||||||
|
function:
|
||||||
|
is_plugin_path: 检查文件路径是否来自插件目录
|
||||||
|
get_short_level_name: 将日志级别名称转换为四个字母的缩写
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
|
||||||
|
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
|
||||||
|
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
|
||||||
|
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import colorlog
|
import colorlog
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
# 日志缓存大小
|
||||||
CACHED_SIZE = 200
|
CACHED_SIZE = 200
|
||||||
|
# 日志颜色配置
|
||||||
log_color_config = {
|
log_color_config = {
|
||||||
"DEBUG": "green",
|
"DEBUG": "green",
|
||||||
"INFO": "bold_cyan",
|
"INFO": "bold_cyan",
|
||||||
@@ -19,8 +45,13 @@ log_color_config = {
|
|||||||
|
|
||||||
|
|
||||||
def is_plugin_path(pathname):
|
def is_plugin_path(pathname):
|
||||||
"""
|
"""检查文件路径是否来自插件目录
|
||||||
检查文件路径是否来自插件目录
|
|
||||||
|
Args:
|
||||||
|
pathname (str): 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果路径来自插件目录,则返回 True,否则返回 False
|
||||||
"""
|
"""
|
||||||
if not pathname:
|
if not pathname:
|
||||||
return False
|
return False
|
||||||
@@ -30,8 +61,13 @@ def is_plugin_path(pathname):
|
|||||||
|
|
||||||
|
|
||||||
def get_short_level_name(level_name):
|
def get_short_level_name(level_name):
|
||||||
"""
|
"""将日志级别名称转换为四个字母的缩写
|
||||||
将日志级别名称转换为四个字母的缩写
|
|
||||||
|
Args:
|
||||||
|
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 四个字母的日志级别缩写
|
||||||
"""
|
"""
|
||||||
level_map = {
|
level_map = {
|
||||||
"DEBUG": "DBUG",
|
"DEBUG": "DBUG",
|
||||||
@@ -44,12 +80,21 @@ def get_short_level_name(level_name):
|
|||||||
|
|
||||||
|
|
||||||
class LogBroker:
|
class LogBroker:
|
||||||
|
"""日志代理类, 用于缓存和分发日志消息
|
||||||
|
|
||||||
|
发布-订阅模式
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log_cache = deque(maxlen=CACHED_SIZE)
|
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
|
||||||
self.subscribers: List[Queue] = []
|
self.subscribers: List[Queue] = [] # 订阅者列表
|
||||||
|
|
||||||
def register(self) -> Queue:
|
def register(self) -> Queue:
|
||||||
"""给每个订阅者返回一个带有日志缓存的队列"""
|
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Queue: 订阅者的队列, 可用于接收日志消息
|
||||||
|
"""
|
||||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||||
for log in self.log_cache:
|
for log in self.log_cache:
|
||||||
q.put_nowait(log)
|
q.put_nowait(log)
|
||||||
@@ -57,11 +102,20 @@ class LogBroker:
|
|||||||
return q
|
return q
|
||||||
|
|
||||||
def unregister(self, q: Queue):
|
def unregister(self, q: Queue):
|
||||||
"""取消订阅"""
|
"""取消订阅
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (Queue): 需要取消订阅的队列
|
||||||
|
"""
|
||||||
self.subscribers.remove(q)
|
self.subscribers.remove(q)
|
||||||
|
|
||||||
def publish(self, log_entry: str):
|
def publish(self, log_entry: dict):
|
||||||
"""发布消息"""
|
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_entry (dict): 日志消息, 包含日志级别和日志内容.
|
||||||
|
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
|
||||||
|
"""
|
||||||
self.log_cache.append(log_entry)
|
self.log_cache.append(log_entry)
|
||||||
for q in self.subscribers:
|
for q in self.subscribers:
|
||||||
try:
|
try:
|
||||||
@@ -71,24 +125,61 @@ class LogBroker:
|
|||||||
|
|
||||||
|
|
||||||
class LogQueueHandler(logging.Handler):
|
class LogQueueHandler(logging.Handler):
|
||||||
|
"""日志处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
|
||||||
|
继承自 logging.Handler
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, log_broker: LogBroker):
|
def __init__(self, log_broker: LogBroker):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
|
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
|
||||||
|
这个方法会在每次日志记录时被调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record (logging.LogRecord): 日志记录对象, 包含日志信息
|
||||||
|
"""
|
||||||
log_entry = self.format(record)
|
log_entry = self.format(record)
|
||||||
self.log_broker.publish(log_entry)
|
self.log_broker.publish(
|
||||||
|
{
|
||||||
|
"level": record.levelname,
|
||||||
|
"time": record.asctime,
|
||||||
|
"data": log_entry,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogManager:
|
class LogManager:
|
||||||
|
"""日志管理器, 用于创建和配置日志记录器
|
||||||
|
|
||||||
|
提供了获取默认日志记录器logger和设置队列处理器的方法
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def GetLogger(cls, log_name: str = "default"):
|
def GetLogger(cls, log_name: str = "default"):
|
||||||
|
"""获取指定名称的日志记录器logger
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_name (str): 日志记录器的名称, 默认为 "default"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logging.Logger: 返回配置好的日志记录器
|
||||||
|
"""
|
||||||
logger = logging.getLogger(log_name)
|
logger = logging.getLogger(log_name)
|
||||||
|
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
|
||||||
if logger.hasHandlers():
|
if logger.hasHandlers():
|
||||||
return logger
|
return logger
|
||||||
console_handler = logging.StreamHandler()
|
# 如果logger没有处理器
|
||||||
console_handler.setLevel(logging.DEBUG)
|
console_handler = logging.StreamHandler(
|
||||||
|
sys.stdout
|
||||||
|
) # 创建一个StreamHandler用于控制台输出
|
||||||
|
console_handler.setLevel(
|
||||||
|
logging.DEBUG
|
||||||
|
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||||
|
|
||||||
|
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
|
||||||
console_formatter = colorlog.ColoredFormatter(
|
console_formatter = colorlog.ColoredFormatter(
|
||||||
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
|
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
|
||||||
datefmt="%H:%M:%S",
|
datefmt="%H:%M:%S",
|
||||||
@@ -96,6 +187,8 @@ class LogManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
class PluginFilter(logging.Filter):
|
class PluginFilter(logging.Filter):
|
||||||
|
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
|
||||||
|
|
||||||
def filter(self, record):
|
def filter(self, record):
|
||||||
record.plugin_tag = (
|
record.plugin_tag = (
|
||||||
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
|
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
|
||||||
@@ -103,6 +196,9 @@ class LogManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
class FileNameFilter(logging.Filter):
|
class FileNameFilter(logging.Filter):
|
||||||
|
"""文件名过滤器类, 用于修改日志记录的文件名格式
|
||||||
|
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
|
||||||
|
|
||||||
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
|
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
|
||||||
def filter(self, record):
|
def filter(self, record):
|
||||||
dirname = os.path.dirname(record.pathname)
|
dirname = os.path.dirname(record.pathname)
|
||||||
@@ -114,22 +210,30 @@ class LogManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
class LevelNameFilter(logging.Filter):
|
class LevelNameFilter(logging.Filter):
|
||||||
|
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
|
||||||
|
|
||||||
# 添加短日志级别名称
|
# 添加短日志级别名称
|
||||||
def filter(self, record):
|
def filter(self, record):
|
||||||
record.short_levelname = get_short_level_name(record.levelname)
|
record.short_levelname = get_short_level_name(record.levelname)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
console_handler.setFormatter(console_formatter)
|
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
|
||||||
logger.addFilter(PluginFilter())
|
logger.addFilter(PluginFilter()) # 添加插件过滤器
|
||||||
logger.addFilter(FileNameFilter())
|
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
|
||||||
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
|
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler) # 添加处理器到logger
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
||||||
|
"""设置队列处理器, 用于将日志消息发送到 LogBroker
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger (logging.Logger): 日志记录器
|
||||||
|
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
|
||||||
|
"""
|
||||||
handler = LogQueueHandler(log_broker)
|
handler = LogQueueHandler(log_broker)
|
||||||
handler.setLevel(logging.DEBUG)
|
handler.setLevel(logging.DEBUG)
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
|
|||||||
@@ -22,13 +22,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|||||||
SOFTWARE.
|
SOFTWARE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import typing as T
|
import typing as T
|
||||||
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic.v1 import BaseModel
|
from pydantic.v1 import BaseModel
|
||||||
|
|
||||||
|
from astrbot.core import astrbot_config, file_token_service, logger
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||||
|
|
||||||
|
|
||||||
class ComponentType(Enum):
|
class ComponentType(Enum):
|
||||||
Plain = "Plain" # 纯文本消息
|
Plain = "Plain" # 纯文本消息
|
||||||
@@ -59,6 +66,8 @@ class ComponentType(Enum):
|
|||||||
TTS = "TTS"
|
TTS = "TTS"
|
||||||
Unknown = "Unknown"
|
Unknown = "Unknown"
|
||||||
|
|
||||||
|
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageComponent(BaseModel):
|
class BaseMessageComponent(BaseModel):
|
||||||
type: ComponentType
|
type: ComponentType
|
||||||
@@ -93,6 +102,10 @@ class BaseMessageComponent(BaseModel):
|
|||||||
data[k] = v
|
data[k] = v
|
||||||
return {"type": self.type.lower(), "data": data}
|
return {"type": self.type.lower(), "data": data}
|
||||||
|
|
||||||
|
async def to_dict(self) -> dict:
|
||||||
|
# 默认情况下,回退到旧的同步 toDict()
|
||||||
|
return self.toDict()
|
||||||
|
|
||||||
|
|
||||||
class Plain(BaseMessageComponent):
|
class Plain(BaseMessageComponent):
|
||||||
type: ComponentType = "Plain"
|
type: ComponentType = "Plain"
|
||||||
@@ -109,6 +122,9 @@ class Plain(BaseMessageComponent):
|
|||||||
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def toDict(self):
|
||||||
|
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||||
|
|
||||||
|
|
||||||
class Face(BaseMessageComponent):
|
class Face(BaseMessageComponent):
|
||||||
type: ComponentType = "Face"
|
type: ComponentType = "Face"
|
||||||
@@ -146,6 +162,76 @@ class Record(BaseMessageComponent):
|
|||||||
return Record(file=url, **_)
|
return Record(file=url, **_)
|
||||||
raise Exception("not a valid url")
|
raise Exception("not a valid url")
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 语音的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
if self.file and self.file.startswith("file:///"):
|
||||||
|
file_path = self.file[8:]
|
||||||
|
return file_path
|
||||||
|
elif self.file and self.file.startswith("http"):
|
||||||
|
file_path = await download_image_by_url(self.file)
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
elif self.file and self.file.startswith("base64://"):
|
||||||
|
bs64_data = self.file.removeprefix("base64://")
|
||||||
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
elif os.path.exists(self.file):
|
||||||
|
file_path = self.file
|
||||||
|
return os.path.abspath(file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
|
||||||
|
async def convert_to_base64(self) -> str:
|
||||||
|
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
|
"""
|
||||||
|
# convert to base64
|
||||||
|
if self.file and self.file.startswith("file:///"):
|
||||||
|
bs64_data = file_to_base64(self.file[8:])
|
||||||
|
elif self.file and self.file.startswith("http"):
|
||||||
|
file_path = await download_image_by_url(self.file)
|
||||||
|
bs64_data = file_to_base64(file_path)
|
||||||
|
elif self.file and self.file.startswith("base64://"):
|
||||||
|
bs64_data = self.file
|
||||||
|
elif os.path.exists(self.file):
|
||||||
|
bs64_data = file_to_base64(self.file)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
bs64_data = bs64_data.removeprefix("base64://")
|
||||||
|
return bs64_data
|
||||||
|
|
||||||
|
async def register_to_file_service(self) -> str:
|
||||||
|
"""
|
||||||
|
将语音注册到文件服务。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 注册后的URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果未配置 callback_api_base
|
||||||
|
"""
|
||||||
|
callback_host = astrbot_config.get("callback_api_base")
|
||||||
|
|
||||||
|
if not callback_host:
|
||||||
|
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||||
|
|
||||||
|
file_path = await self.convert_to_file_path()
|
||||||
|
|
||||||
|
token = await file_token_service.register_file(file_path)
|
||||||
|
|
||||||
|
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||||
|
|
||||||
|
return f"{callback_host}/api/file/{token}"
|
||||||
|
|
||||||
|
|
||||||
class Video(BaseMessageComponent):
|
class Video(BaseMessageComponent):
|
||||||
type: ComponentType = "Video"
|
type: ComponentType = "Video"
|
||||||
@@ -156,9 +242,6 @@ class Video(BaseMessageComponent):
|
|||||||
path: T.Optional[str] = ""
|
path: T.Optional[str] = ""
|
||||||
|
|
||||||
def __init__(self, file: str, **_):
|
def __init__(self, file: str, **_):
|
||||||
# for k in _.keys():
|
|
||||||
# if k == "c" and _[k] not in [2, 3]:
|
|
||||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
|
||||||
super().__init__(file=file, **_)
|
super().__init__(file=file, **_)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -171,6 +254,70 @@ class Video(BaseMessageComponent):
|
|||||||
return Video(file=url, **_)
|
return Video(file=url, **_)
|
||||||
raise Exception("not a valid url")
|
raise Exception("not a valid url")
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 视频的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
url = self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
return url[8:]
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||||
|
await download_file(url, video_file_path)
|
||||||
|
if os.path.exists(video_file_path):
|
||||||
|
return os.path.abspath(video_file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"download failed: {url}")
|
||||||
|
elif os.path.exists(url):
|
||||||
|
return os.path.abspath(url)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
|
||||||
|
async def register_to_file_service(self):
|
||||||
|
"""
|
||||||
|
将视频注册到文件服务。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 注册后的URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果未配置 callback_api_base
|
||||||
|
"""
|
||||||
|
callback_host = astrbot_config.get("callback_api_base")
|
||||||
|
|
||||||
|
if not callback_host:
|
||||||
|
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||||
|
|
||||||
|
file_path = await self.convert_to_file_path()
|
||||||
|
|
||||||
|
token = await file_token_service.register_file(file_path)
|
||||||
|
|
||||||
|
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||||
|
|
||||||
|
return f"{callback_host}/api/file/{token}"
|
||||||
|
|
||||||
|
async def to_dict(self):
|
||||||
|
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||||
|
url_or_path = self.file
|
||||||
|
if url_or_path.startswith("http"):
|
||||||
|
payload_file = url_or_path
|
||||||
|
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||||
|
callback_host = str(callback_host).removesuffix("/")
|
||||||
|
token = await file_token_service.register_file(url_or_path)
|
||||||
|
payload_file = f"{callback_host}/api/file/{token}"
|
||||||
|
logger.debug(f"Generated video file callback link: {payload_file}")
|
||||||
|
else:
|
||||||
|
payload_file = url_or_path
|
||||||
|
return {
|
||||||
|
"type": "video",
|
||||||
|
"data": {
|
||||||
|
"file": payload_file,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class At(BaseMessageComponent):
|
class At(BaseMessageComponent):
|
||||||
type: ComponentType = "At"
|
type: ComponentType = "At"
|
||||||
@@ -180,6 +327,12 @@ class At(BaseMessageComponent):
|
|||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
def toDict(self):
|
||||||
|
return {
|
||||||
|
"type": "at",
|
||||||
|
"data": {"qq": str(self.qq)},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class AtAll(At):
|
class AtAll(At):
|
||||||
qq: str = "all"
|
qq: str = "all"
|
||||||
@@ -279,10 +432,6 @@ class Image(BaseMessageComponent):
|
|||||||
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
||||||
|
|
||||||
def __init__(self, file: T.Optional[str], **_):
|
def __init__(self, file: T.Optional[str], **_):
|
||||||
# for k in _.keys():
|
|
||||||
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
|
|
||||||
# (k == "c" and _[k] not in [2, 3]):
|
|
||||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
|
||||||
super().__init__(file=file, **_)
|
super().__init__(file=file, **_)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -307,14 +456,100 @@ class Image(BaseMessageComponent):
|
|||||||
def fromIO(IO):
|
def fromIO(IO):
|
||||||
return Image.fromBytes(IO.read())
|
return Image.fromBytes(IO.read())
|
||||||
|
|
||||||
|
async def convert_to_file_path(self) -> str:
|
||||||
|
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 图片的本地路径,以绝对路径表示。
|
||||||
|
"""
|
||||||
|
url = self.url if self.url else self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
image_file_path = url[8:]
|
||||||
|
return image_file_path
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(url)
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
elif url and url.startswith("base64://"):
|
||||||
|
bs64_data = url.removeprefix("base64://")
|
||||||
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||||
|
with open(image_file_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
elif os.path.exists(url):
|
||||||
|
image_file_path = url
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
|
||||||
|
async def convert_to_base64(self) -> str:
|
||||||
|
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
|
"""
|
||||||
|
# convert to base64
|
||||||
|
url = self.url if self.url else self.file
|
||||||
|
if url and url.startswith("file:///"):
|
||||||
|
bs64_data = file_to_base64(url[8:])
|
||||||
|
elif url and url.startswith("http"):
|
||||||
|
image_file_path = await download_image_by_url(url)
|
||||||
|
bs64_data = file_to_base64(image_file_path)
|
||||||
|
elif url and url.startswith("base64://"):
|
||||||
|
bs64_data = url
|
||||||
|
elif os.path.exists(url):
|
||||||
|
bs64_data = file_to_base64(url)
|
||||||
|
else:
|
||||||
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
bs64_data = bs64_data.removeprefix("base64://")
|
||||||
|
return bs64_data
|
||||||
|
|
||||||
|
async def register_to_file_service(self) -> str:
|
||||||
|
"""
|
||||||
|
将图片注册到文件服务。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 注册后的URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果未配置 callback_api_base
|
||||||
|
"""
|
||||||
|
callback_host = astrbot_config.get("callback_api_base")
|
||||||
|
|
||||||
|
if not callback_host:
|
||||||
|
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||||
|
|
||||||
|
file_path = await self.convert_to_file_path()
|
||||||
|
|
||||||
|
token = await file_token_service.register_file(file_path)
|
||||||
|
|
||||||
|
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||||
|
|
||||||
|
return f"{callback_host}/api/file/{token}"
|
||||||
|
|
||||||
|
|
||||||
class Reply(BaseMessageComponent):
|
class Reply(BaseMessageComponent):
|
||||||
type: ComponentType = "Reply"
|
type: ComponentType = "Reply"
|
||||||
id: T.Union[str, int]
|
id: T.Union[str, int]
|
||||||
text: T.Optional[str] = ""
|
"""所引用的消息 ID"""
|
||||||
qq: T.Optional[int] = 0
|
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||||
|
"""被引用的消息段列表"""
|
||||||
|
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||||
|
"""被引用的消息对应的发送者的 ID"""
|
||||||
|
sender_nickname: T.Optional[str] = ""
|
||||||
|
"""被引用的消息对应的发送者的昵称"""
|
||||||
time: T.Optional[int] = 0
|
time: T.Optional[int] = 0
|
||||||
|
"""被引用的消息发送时间"""
|
||||||
|
message_str: T.Optional[str] = ""
|
||||||
|
"""被引用的消息解析后的纯文本消息字符串"""
|
||||||
|
|
||||||
|
text: T.Optional[str] = ""
|
||||||
|
"""deprecated"""
|
||||||
|
qq: T.Optional[int] = 0
|
||||||
|
"""deprecated"""
|
||||||
seq: T.Optional[int] = 0
|
seq: T.Optional[int] = 0
|
||||||
|
"""deprecated"""
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
@@ -352,22 +587,48 @@ class Node(BaseMessageComponent):
|
|||||||
type: ComponentType = "Node"
|
type: ComponentType = "Node"
|
||||||
id: T.Optional[int] = 0 # 忽略
|
id: T.Optional[int] = 0 # 忽略
|
||||||
name: T.Optional[str] = "" # qq昵称
|
name: T.Optional[str] = "" # qq昵称
|
||||||
uin: T.Optional[int] = 0 # qq号
|
uin: T.Optional[str] = "0" # qq号
|
||||||
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
|
content: T.Optional[list[BaseMessageComponent]] = []
|
||||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||||
time: T.Optional[int] = 0
|
time: T.Optional[int] = 0 # 忽略
|
||||||
|
|
||||||
def __init__(self, content: T.Union[str, list], **_):
|
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||||
if isinstance(content, list):
|
if isinstance(content, Node):
|
||||||
_content = ""
|
# back
|
||||||
for chain in content:
|
content = [content]
|
||||||
_content += chain.toString()
|
|
||||||
content = _content
|
|
||||||
super().__init__(content=content, **_)
|
super().__init__(content=content, **_)
|
||||||
|
|
||||||
def toString(self):
|
async def to_dict(self):
|
||||||
# logger.warn("Protocol: node doesn't support stringify")
|
data_content = []
|
||||||
return ""
|
for comp in self.content:
|
||||||
|
if isinstance(comp, (Image, Record)):
|
||||||
|
# For Image and Record segments, we convert them to base64
|
||||||
|
bs64 = await comp.convert_to_base64()
|
||||||
|
data_content.append(
|
||||||
|
{
|
||||||
|
"type": comp.type.lower(),
|
||||||
|
"data": {"file": f"base64://{bs64}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(comp, File):
|
||||||
|
# For File segments, we need to handle the file differently
|
||||||
|
d = await comp.to_dict()
|
||||||
|
data_content.append(d)
|
||||||
|
elif isinstance(comp, (Node, Nodes)):
|
||||||
|
# For Node segments, we recursively convert them to dict
|
||||||
|
d = await comp.to_dict()
|
||||||
|
data_content.append(d)
|
||||||
|
else:
|
||||||
|
d = comp.toDict()
|
||||||
|
data_content.append(d)
|
||||||
|
return {
|
||||||
|
"type": "node",
|
||||||
|
"data": {
|
||||||
|
"user_id": str(self.uin),
|
||||||
|
"nickname": self.name,
|
||||||
|
"content": data_content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Nodes(BaseMessageComponent):
|
class Nodes(BaseMessageComponent):
|
||||||
@@ -378,7 +639,22 @@ class Nodes(BaseMessageComponent):
|
|||||||
super().__init__(nodes=nodes, **_)
|
super().__init__(nodes=nodes, **_)
|
||||||
|
|
||||||
def toDict(self):
|
def toDict(self):
|
||||||
return {"messages": [node.toDict() for node in self.nodes]}
|
"""Deprecated. Use to_dict instead"""
|
||||||
|
ret = {
|
||||||
|
"messages": [],
|
||||||
|
}
|
||||||
|
for node in self.nodes:
|
||||||
|
d = node.toDict()
|
||||||
|
ret["messages"].append(d)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def to_dict(self):
|
||||||
|
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||||
|
ret = {"messages": []}
|
||||||
|
for node in self.nodes:
|
||||||
|
d = await node.to_dict()
|
||||||
|
ret["messages"].append(d)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class Xml(BaseMessageComponent):
|
class Xml(BaseMessageComponent):
|
||||||
@@ -438,15 +714,146 @@ class Unknown(BaseMessageComponent):
|
|||||||
|
|
||||||
class File(BaseMessageComponent):
|
class File(BaseMessageComponent):
|
||||||
"""
|
"""
|
||||||
目前此消息段只适配了 Napcat。
|
文件消息段
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ComponentType = "File"
|
type: ComponentType = "File"
|
||||||
name: T.Optional[str] = "" # 名字
|
name: T.Optional[str] = "" # 名字
|
||||||
file: T.Optional[str] = "" # url(本地路径)
|
file_: T.Optional[str] = "" # 本地路径
|
||||||
|
url: T.Optional[str] = "" # url
|
||||||
|
|
||||||
def __init__(self, name: str, file: str):
|
def __init__(self, name: str, file: str = "", url: str = ""):
|
||||||
super().__init__(name=name, file=file)
|
"""文件消息段。"""
|
||||||
|
super().__init__(name=name, file_=file, url=url)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def file(self) -> str:
|
||||||
|
"""
|
||||||
|
获取文件路径,如果文件不存在但有URL,则同步下载文件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 文件路径
|
||||||
|
"""
|
||||||
|
if self.file_ and os.path.exists(self.file_):
|
||||||
|
return os.path.abspath(self.file_)
|
||||||
|
|
||||||
|
if self.url:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
logger.warning(
|
||||||
|
(
|
||||||
|
"不可以在异步上下文中同步等待下载! "
|
||||||
|
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||||
|
"请使用 await get_file() 代替直接获取 <File>.file 字段"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
# 等待下载完成
|
||||||
|
loop.run_until_complete(self._download_file())
|
||||||
|
|
||||||
|
if self.file_ and os.path.exists(self.file_):
|
||||||
|
return os.path.abspath(self.file_)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文件下载失败: {e}")
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@file.setter
|
||||||
|
def file(self, value: str):
|
||||||
|
"""
|
||||||
|
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value (str): 文件路径或URL
|
||||||
|
"""
|
||||||
|
if value.startswith("http://") or value.startswith("https://"):
|
||||||
|
self.url = value
|
||||||
|
else:
|
||||||
|
self.file_ = value
|
||||||
|
|
||||||
|
async def get_file(self, allow_return_url: bool = False) -> str:
|
||||||
|
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
|
||||||
|
注意,如果为 True,也可能返回文件路径。
|
||||||
|
Returns:
|
||||||
|
str: 文件路径或者 http 下载链接
|
||||||
|
"""
|
||||||
|
if allow_return_url and self.url:
|
||||||
|
return self.url
|
||||||
|
|
||||||
|
if self.file_ and os.path.exists(self.file_):
|
||||||
|
return os.path.abspath(self.file_)
|
||||||
|
|
||||||
|
if self.url:
|
||||||
|
await self._download_file()
|
||||||
|
return os.path.abspath(self.file_)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def _download_file(self):
|
||||||
|
"""下载文件"""
|
||||||
|
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
os.makedirs(download_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||||
|
await download_file(self.url, file_path)
|
||||||
|
self.file_ = os.path.abspath(file_path)
|
||||||
|
|
||||||
|
async def register_to_file_service(self):
|
||||||
|
"""
|
||||||
|
将文件注册到文件服务。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 注册后的URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果未配置 callback_api_base
|
||||||
|
"""
|
||||||
|
callback_host = astrbot_config.get("callback_api_base")
|
||||||
|
|
||||||
|
if not callback_host:
|
||||||
|
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||||
|
|
||||||
|
file_path = await self.get_file()
|
||||||
|
|
||||||
|
token = await file_token_service.register_file(file_path)
|
||||||
|
|
||||||
|
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||||
|
|
||||||
|
return f"{callback_host}/api/file/{token}"
|
||||||
|
|
||||||
|
async def to_dict(self):
|
||||||
|
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||||
|
url_or_path = await self.get_file(allow_return_url=True)
|
||||||
|
if url_or_path.startswith("http"):
|
||||||
|
payload_file = url_or_path
|
||||||
|
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||||
|
callback_host = str(callback_host).removesuffix("/")
|
||||||
|
token = await file_token_service.register_file(url_or_path)
|
||||||
|
payload_file = f"{callback_host}/api/file/{token}"
|
||||||
|
logger.debug(f"Generated file callback link: {payload_file}")
|
||||||
|
else:
|
||||||
|
payload_file = url_or_path
|
||||||
|
return {
|
||||||
|
"type": "file",
|
||||||
|
"data": {
|
||||||
|
"name": self.name,
|
||||||
|
"file": payload_file,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WechatEmoji(BaseMessageComponent):
|
||||||
|
type: ComponentType = "WechatEmoji"
|
||||||
|
md5: T.Optional[str] = ""
|
||||||
|
md5_len: T.Optional[int] = 0
|
||||||
|
cdnurl: T.Optional[str] = ""
|
||||||
|
|
||||||
|
def __init__(self, **_):
|
||||||
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
ComponentTypes = {
|
ComponentTypes = {
|
||||||
@@ -477,4 +884,5 @@ ComponentTypes = {
|
|||||||
"tts": TTS,
|
"tts": TTS,
|
||||||
"unknown": Unknown,
|
"unknown": Unknown,
|
||||||
"file": File,
|
"file": File,
|
||||||
|
"WechatEmoji": WechatEmoji,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
import enum
|
import enum
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union, AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
|
from astrbot.core.message.components import (
|
||||||
|
BaseMessageComponent,
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
At,
|
||||||
|
AtAll,
|
||||||
|
)
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +37,30 @@ class MessageChain:
|
|||||||
self.chain.append(Plain(message))
|
self.chain.append(Plain(message))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def at(self, name: str, qq: Union[str, int]):
|
||||||
|
"""添加一条 At 消息到消息链 `chain` 中。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().at("张三", "12345678910")
|
||||||
|
# 输出 @张三
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chain.append(At(name=name, qq=qq))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def at_all(self):
|
||||||
|
"""添加一条 AtAll 消息到消息链 `chain` 中。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
CommandResult().at_all()
|
||||||
|
# 输出 @所有人
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chain.append(AtAll())
|
||||||
|
return self
|
||||||
|
|
||||||
@deprecated("请使用 message 方法代替。")
|
@deprecated("请使用 message 方法代替。")
|
||||||
def error(self, message: str):
|
def error(self, message: str):
|
||||||
"""添加一条错误消息到消息链 `chain` 中
|
"""添加一条错误消息到消息链 `chain` 中
|
||||||
@@ -77,6 +107,34 @@ class MessageChain:
|
|||||||
self.use_t2i_ = use_t2i
|
self.use_t2i_ = use_t2i
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_plain_text(self) -> str:
|
||||||
|
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||||
|
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||||
|
|
||||||
|
def squash_plain(self):
|
||||||
|
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||||
|
if not self.chain:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_chain = []
|
||||||
|
first_plain = None
|
||||||
|
plain_texts = []
|
||||||
|
|
||||||
|
for comp in self.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
if first_plain is None:
|
||||||
|
first_plain = comp
|
||||||
|
new_chain.append(comp)
|
||||||
|
plain_texts.append(comp.text)
|
||||||
|
else:
|
||||||
|
new_chain.append(comp)
|
||||||
|
|
||||||
|
if first_plain is not None:
|
||||||
|
first_plain.text = "".join(plain_texts)
|
||||||
|
|
||||||
|
self.chain = new_chain
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class EventResultType(enum.Enum):
|
class EventResultType(enum.Enum):
|
||||||
"""用于描述事件处理的结果类型。
|
"""用于描述事件处理的结果类型。
|
||||||
@@ -97,6 +155,10 @@ class ResultContentType(enum.Enum):
|
|||||||
"""调用 LLM 产生的结果"""
|
"""调用 LLM 产生的结果"""
|
||||||
GENERAL_RESULT = enum.auto()
|
GENERAL_RESULT = enum.auto()
|
||||||
"""普通的消息结果"""
|
"""普通的消息结果"""
|
||||||
|
STREAMING_RESULT = enum.auto()
|
||||||
|
"""调用 LLM 产生的流式结果"""
|
||||||
|
STREAMING_FINISH= enum.auto()
|
||||||
|
"""流式输出完成"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -118,6 +180,9 @@ class MessageEventResult(MessageChain):
|
|||||||
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async_stream: Optional[AsyncGenerator] = None
|
||||||
|
"""异步流"""
|
||||||
|
|
||||||
def stop_event(self) -> "MessageEventResult":
|
def stop_event(self) -> "MessageEventResult":
|
||||||
"""终止事件传播。"""
|
"""终止事件传播。"""
|
||||||
self.result_type = EventResultType.STOP
|
self.result_type = EventResultType.STOP
|
||||||
@@ -134,6 +199,11 @@ class MessageEventResult(MessageChain):
|
|||||||
"""
|
"""
|
||||||
return self.result_type == EventResultType.STOP
|
return self.result_type == EventResultType.STOP
|
||||||
|
|
||||||
|
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
|
||||||
|
"""设置异步流。"""
|
||||||
|
self.async_stream = stream
|
||||||
|
return self
|
||||||
|
|
||||||
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
||||||
"""设置事件处理的结果类型。
|
"""设置事件处理的结果类型。
|
||||||
|
|
||||||
@@ -147,9 +217,6 @@ class MessageEventResult(MessageChain):
|
|||||||
"""是否为 LLM 结果。"""
|
"""是否为 LLM 结果。"""
|
||||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||||
|
|
||||||
def get_plain_text(self) -> str:
|
|
||||||
"""获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
|
||||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
|
||||||
|
|
||||||
|
|
||||||
|
# 为了兼容旧版代码,保留 CommandResult 的别名
|
||||||
CommandResult = MessageEventResult
|
CommandResult = MessageEventResult
|
||||||
|
|||||||
@@ -7,16 +7,19 @@ from .waking_check.stage import WakingCheckStage
|
|||||||
from .whitelist_check.stage import WhitelistCheckStage
|
from .whitelist_check.stage import WhitelistCheckStage
|
||||||
from .rate_limit_check.stage import RateLimitStage
|
from .rate_limit_check.stage import RateLimitStage
|
||||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||||
|
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||||
from .preprocess_stage.stage import PreProcessStage
|
from .preprocess_stage.stage import PreProcessStage
|
||||||
from .process_stage.stage import ProcessStage
|
from .process_stage.stage import ProcessStage
|
||||||
from .result_decorate.stage import ResultDecorateStage
|
from .result_decorate.stage import ResultDecorateStage
|
||||||
from .respond.stage import RespondStage
|
from .respond.stage import RespondStage
|
||||||
|
|
||||||
|
# 管道阶段顺序
|
||||||
STAGES_ORDER = [
|
STAGES_ORDER = [
|
||||||
"WakingCheckStage", # 检查是否需要唤醒
|
"WakingCheckStage", # 检查是否需要唤醒
|
||||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||||
"RateLimitStage", # 检查会话是否超过频率限制
|
"RateLimitStage", # 检查会话是否超过频率限制
|
||||||
"ContentSafetyCheckStage", # 检查内容安全
|
"ContentSafetyCheckStage", # 检查内容安全
|
||||||
|
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||||
"PreProcessStage", # 预处理
|
"PreProcessStage", # 预处理
|
||||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||||
@@ -28,6 +31,7 @@ __all__ = [
|
|||||||
"WhitelistCheckStage",
|
"WhitelistCheckStage",
|
||||||
"RateLimitStage",
|
"RateLimitStage",
|
||||||
"ContentSafetyCheckStage",
|
"ContentSafetyCheckStage",
|
||||||
|
"PlatformCompatibilityStage",
|
||||||
"PreProcessStage",
|
"PreProcessStage",
|
||||||
"ProcessStage",
|
"ProcessStage",
|
||||||
"ResultDecorateStage",
|
"ResultDecorateStage",
|
||||||
|
|||||||
@@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineContext:
|
class PipelineContext:
|
||||||
astrbot_config: AstrBotConfig
|
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||||
plugin_manager: PluginManager
|
|
||||||
|
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||||
|
plugin_manager: PluginManager # 插件管理器对象
|
||||||
|
|||||||
56
astrbot/core/pipeline/platform_compatibility/stage.py
Normal file
56
astrbot/core/pipeline/platform_compatibility/stage.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from ..stage import Stage, register_stage
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from typing import Union, AsyncGenerator
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
@register_stage
|
||||||
|
class PlatformCompatibilityStage(Stage):
|
||||||
|
"""检查所有处理器的平台兼容性。
|
||||||
|
|
||||||
|
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
"""初始化平台兼容性检查阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
|
self.ctx = ctx
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self, event: AstrMessageEvent
|
||||||
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
# 获取当前平台ID
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
|
||||||
|
# 获取已激活的处理器
|
||||||
|
activated_handlers = event.get_extra("activated_handlers")
|
||||||
|
if activated_handlers is None:
|
||||||
|
activated_handlers = []
|
||||||
|
|
||||||
|
# 标记不兼容的处理器
|
||||||
|
for handler in activated_handlers:
|
||||||
|
if not isinstance(handler, StarHandlerMetadata):
|
||||||
|
continue
|
||||||
|
# 检查处理器是否在当前平台启用
|
||||||
|
enabled = handler.is_enabled_for_platform(platform_id)
|
||||||
|
if not enabled:
|
||||||
|
if handler.handler_module_path in star_map:
|
||||||
|
plugin_name = star_map[handler.handler_module_path].name
|
||||||
|
logger.debug(
|
||||||
|
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||||
|
)
|
||||||
|
# 设置处理器为平台不兼容状态
|
||||||
|
# TODO: 更好的标记方式
|
||||||
|
handler.platform_compatible = False
|
||||||
|
else:
|
||||||
|
# 确保处理器为平台兼容状态
|
||||||
|
handler.platform_compatible = True
|
||||||
|
|
||||||
|
# 更新已激活的处理器列表
|
||||||
|
event.set_extra("activated_handlers", activated_handlers)
|
||||||
@@ -46,28 +46,29 @@ class PreProcessStage(Stage):
|
|||||||
stt_provider = (
|
stt_provider = (
|
||||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||||
)
|
)
|
||||||
if stt_provider:
|
if not stt_provider:
|
||||||
message_chain = event.get_messages()
|
return
|
||||||
for idx, component in enumerate(message_chain):
|
message_chain = event.get_messages()
|
||||||
if isinstance(component, Record) and component.url:
|
for idx, component in enumerate(message_chain):
|
||||||
path = component.url.removeprefix("file://")
|
if isinstance(component, Record) and component.url:
|
||||||
retry = 5
|
path = component.url.removeprefix("file://")
|
||||||
for i in range(retry):
|
retry = 5
|
||||||
try:
|
for i in range(retry):
|
||||||
result = await stt_provider.get_text(audio_url=path)
|
try:
|
||||||
if result:
|
result = await stt_provider.get_text(audio_url=path)
|
||||||
logger.info("语音转文本结果: " + result)
|
if result:
|
||||||
message_chain[idx] = Plain(result)
|
logger.info("语音转文本结果: " + result)
|
||||||
event.message_str += result
|
message_chain[idx] = Plain(result)
|
||||||
event.message_obj.message_str += result
|
event.message_str += result
|
||||||
break
|
event.message_obj.message_str += result
|
||||||
except FileNotFoundError as e:
|
break
|
||||||
# napcat workaround
|
except FileNotFoundError as e:
|
||||||
logger.warning(e)
|
# napcat workaround
|
||||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
logger.warning(e)
|
||||||
await asyncio.sleep(0.5)
|
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||||
continue
|
await asyncio.sleep(0.5)
|
||||||
except BaseException as e:
|
continue
|
||||||
logger.error(traceback.format_exc())
|
except BaseException as e:
|
||||||
logger.error(f"语音转文本失败: {e}")
|
logger.error(traceback.format_exc())
|
||||||
break
|
logger.error(f"语音转文本失败: {e}")
|
||||||
|
break
|
||||||
|
|||||||
@@ -12,13 +12,27 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
ResultContentType,
|
ResultContentType,
|
||||||
|
MessageChain,
|
||||||
)
|
)
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
|
from astrbot.core.provider.entities import (
|
||||||
|
ProviderRequest,
|
||||||
|
LLMResponse,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
AssistantMessageSegment,
|
||||||
|
ToolCallsResult,
|
||||||
|
)
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
|
from mcp.types import (
|
||||||
|
TextContent,
|
||||||
|
ImageContent,
|
||||||
|
EmbeddedResource,
|
||||||
|
TextResourceContents,
|
||||||
|
BlobResourceContents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestSubStage(Stage):
|
class LLMRequestSubStage(Stage):
|
||||||
@@ -28,6 +42,16 @@ class LLMRequestSubStage(Stage):
|
|||||||
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
||||||
"wake_prefix"
|
"wake_prefix"
|
||||||
] # str
|
] # str
|
||||||
|
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||||
|
"max_context_length"
|
||||||
|
] # int
|
||||||
|
self.dequeue_context_length = min(
|
||||||
|
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||||
|
self.max_context_length - 1,
|
||||||
|
) # int
|
||||||
|
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||||
|
"streaming_response"
|
||||||
|
] # bool
|
||||||
|
|
||||||
for bwp in self.bot_wake_prefixs:
|
for bwp in self.bot_wake_prefixs:
|
||||||
if self.provider_wake_prefix.startswith(bwp):
|
if self.provider_wake_prefix.startswith(bwp):
|
||||||
@@ -43,6 +67,10 @@ class LLMRequestSubStage(Stage):
|
|||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
req: ProviderRequest = None
|
req: ProviderRequest = None
|
||||||
|
|
||||||
|
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||||
|
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||||
|
return
|
||||||
|
|
||||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||||
if provider is None:
|
if provider is None:
|
||||||
return
|
return
|
||||||
@@ -54,7 +82,11 @@ class LLMRequestSubStage(Stage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if req.conversation:
|
if req.conversation:
|
||||||
req.contexts = json.loads(req.conversation.history)
|
all_contexts = json.loads(req.conversation.history)
|
||||||
|
req.contexts = self._process_tool_message_pairs(
|
||||||
|
all_contexts, remove_tags=True
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
req = ProviderRequest(prompt="", image_urls=[])
|
req = ProviderRequest(prompt="", image_urls=[])
|
||||||
if self.provider_wake_prefix:
|
if self.provider_wake_prefix:
|
||||||
@@ -64,8 +96,8 @@ class LLMRequestSubStage(Stage):
|
|||||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||||
for comp in event.message_obj.message:
|
for comp in event.message_obj.message:
|
||||||
if isinstance(comp, Image):
|
if isinstance(comp, Image):
|
||||||
image_url = comp.url if comp.url else comp.file
|
image_path = await comp.convert_to_file_path()
|
||||||
req.image_urls.append(image_url)
|
req.image_urls.append(image_path)
|
||||||
|
|
||||||
# 获取对话上下文
|
# 获取对话上下文
|
||||||
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
||||||
@@ -75,10 +107,16 @@ class LLMRequestSubStage(Stage):
|
|||||||
conversation_id = await self.conv_manager.new_conversation(
|
conversation_id = await self.conv_manager.new_conversation(
|
||||||
event.unified_msg_origin
|
event.unified_msg_origin
|
||||||
)
|
)
|
||||||
req.session_id = event.unified_msg_origin
|
|
||||||
conversation = await self.conv_manager.get_conversation(
|
conversation = await self.conv_manager.get_conversation(
|
||||||
event.unified_msg_origin, conversation_id
|
event.unified_msg_origin, conversation_id
|
||||||
)
|
)
|
||||||
|
if not conversation:
|
||||||
|
conversation_id = await self.conv_manager.new_conversation(
|
||||||
|
event.unified_msg_origin
|
||||||
|
)
|
||||||
|
conversation = await self.conv_manager.get_conversation(
|
||||||
|
event.unified_msg_origin, conversation_id
|
||||||
|
)
|
||||||
req.conversation = conversation
|
req.conversation = conversation
|
||||||
req.contexts = json.loads(conversation.history)
|
req.contexts = json.loads(conversation.history)
|
||||||
|
|
||||||
@@ -89,8 +127,10 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
# 执行请求 LLM 前事件钩子。
|
# 执行请求 LLM 前事件钩子。
|
||||||
# 装饰 system_prompt 等功能
|
# 装饰 system_prompt 等功能
|
||||||
|
# 获取当前平台ID
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
EventType.OnLLMRequestEvent
|
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
@@ -110,110 +150,373 @@ class LLMRequestSubStage(Stage):
|
|||||||
if isinstance(req.contexts, str):
|
if isinstance(req.contexts, str):
|
||||||
req.contexts = json.loads(req.contexts)
|
req.contexts = json.loads(req.contexts)
|
||||||
|
|
||||||
try:
|
# max context length
|
||||||
logger.debug(f"提供商请求 Payload: {req}")
|
if (
|
||||||
if _nested:
|
self.max_context_length != -1 # -1 为不限制
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
and len(req.contexts) // 2 > self.max_context_length
|
||||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
):
|
||||||
|
logger.debug("上下文长度超过限制,将截断。")
|
||||||
# 执行 LLM 响应后的事件钩子。
|
req.contexts = req.contexts[
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||||
EventType.OnLLMResponseEvent
|
]
|
||||||
|
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||||
|
index = next(
|
||||||
|
(
|
||||||
|
i
|
||||||
|
for i, item in enumerate(req.contexts)
|
||||||
|
if item.get("role") == "user"
|
||||||
|
),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
if index is not None and index > 0:
|
||||||
try:
|
req.contexts = req.contexts[index:]
|
||||||
logger.debug(
|
|
||||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
# session_id
|
||||||
|
if not req.session_id:
|
||||||
|
req.session_id = event.unified_msg_origin
|
||||||
|
|
||||||
|
async def requesting(req: ProviderRequest):
|
||||||
|
try:
|
||||||
|
need_loop = True
|
||||||
|
while need_loop:
|
||||||
|
need_loop = False
|
||||||
|
logger.debug(f"提供商请求 Payload: {req}")
|
||||||
|
|
||||||
|
final_llm_response = None
|
||||||
|
|
||||||
|
if self.streaming_response:
|
||||||
|
stream = provider.text_chat_stream(**req.__dict__)
|
||||||
|
async for llm_response in stream:
|
||||||
|
if llm_response.is_chunk:
|
||||||
|
if llm_response.result_chain:
|
||||||
|
yield llm_response.result_chain # MessageChain
|
||||||
|
else:
|
||||||
|
yield MessageChain().message(
|
||||||
|
llm_response.completion_text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_llm_response = llm_response
|
||||||
|
else:
|
||||||
|
final_llm_response = await provider.text_chat(
|
||||||
|
**req.__dict__
|
||||||
|
) # 请求 LLM
|
||||||
|
|
||||||
|
if not final_llm_response:
|
||||||
|
raise Exception("LLM response is None.")
|
||||||
|
|
||||||
|
# 执行 LLM 响应后的事件钩子。
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnLLMResponseEvent
|
||||||
)
|
)
|
||||||
await handler.handler(event, llm_response)
|
for handler in handlers:
|
||||||
except BaseException:
|
try:
|
||||||
logger.error(traceback.format_exc())
|
logger.debug(
|
||||||
|
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler(event, final_llm_response)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.streaming_response:
|
||||||
|
# 流式输出的处理
|
||||||
|
async for result in self._handle_llm_stream_response(
|
||||||
|
event, req, final_llm_response
|
||||||
|
):
|
||||||
|
if isinstance(result, ProviderRequest):
|
||||||
|
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||||
|
req = result
|
||||||
|
need_loop = True
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
# 非流式输出的处理
|
||||||
|
async for result in self._handle_llm_response(
|
||||||
|
event, req, final_llm_response
|
||||||
|
):
|
||||||
|
if isinstance(result, ProviderRequest):
|
||||||
|
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||||
|
req = result
|
||||||
|
need_loop = True
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
llm_tick=1,
|
||||||
|
model_name=provider.get_model(),
|
||||||
|
provider_type=provider.meta().type,
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
# 保存到历史记录
|
|
||||||
await self._save_to_history(event, req, llm_response)
|
|
||||||
|
|
||||||
asyncio.create_task(
|
|
||||||
Metric.upload(
|
|
||||||
llm_tick=1,
|
|
||||||
model_name=provider.get_model(),
|
|
||||||
provider_type=provider.meta().type,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if llm_response.role == "assistant":
|
# 保存到历史记录
|
||||||
# text completion
|
await self._save_to_history(event, req, final_llm_response)
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.streaming_response:
|
||||||
|
event.set_extra("tool_call_result", None)
|
||||||
|
async for _ in requesting(req):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||||
|
.set_async_stream(requesting(req))
|
||||||
|
)
|
||||||
|
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||||
|
yield
|
||||||
|
|
||||||
|
if event.get_extra("tool_call_result"):
|
||||||
|
event.set_result(event.get_extra("tool_call_result"))
|
||||||
|
event.set_extra("tool_call_result", None)
|
||||||
|
yield
|
||||||
|
|
||||||
|
# 暂时直接发出去
|
||||||
|
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||||
|
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||||
|
event.set_extra("tool_call_img_respond", None)
|
||||||
|
yield
|
||||||
|
|
||||||
|
async def _handle_llm_response(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理非流式 LLM 响应。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# text completion
|
||||||
|
if llm_response.result_chain:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult(
|
||||||
|
chain=llm_response.result_chain.chain
|
||||||
|
).set_result_content_type(ResultContentType.LLM_RESULT)
|
||||||
|
)
|
||||||
|
else:
|
||||||
event.set_result(
|
event.set_result(
|
||||||
MessageEventResult()
|
MessageEventResult()
|
||||||
.message(llm_response.completion_text)
|
.message(llm_response.completion_text)
|
||||||
.set_result_content_type(ResultContentType.LLM_RESULT)
|
.set_result_content_type(ResultContentType.LLM_RESULT)
|
||||||
)
|
)
|
||||||
elif llm_response.role == "err":
|
elif llm_response.role == "err":
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "tool":
|
||||||
|
# 处理函数工具调用
|
||||||
|
async for result in self._handle_function_tools(event, req, llm_response):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
async def _handle_llm_stream_response(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理流式 LLM 响应。
|
||||||
|
|
||||||
|
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
if llm_response.role == "assistant":
|
||||||
|
# text completion
|
||||||
|
if llm_response.result_chain:
|
||||||
event.set_result(
|
event.set_result(
|
||||||
MessageEventResult().message(
|
MessageEventResult(
|
||||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
chain=llm_response.result_chain.chain
|
||||||
|
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult()
|
||||||
|
.message(llm_response.completion_text)
|
||||||
|
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "err":
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif llm_response.role == "tool":
|
||||||
|
# 处理函数工具调用
|
||||||
|
async for result in self._handle_function_tools(event, req, llm_response):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
async def _handle_function_tools(
|
||||||
|
self,
|
||||||
|
event: AstrMessageEvent,
|
||||||
|
req: ProviderRequest,
|
||||||
|
llm_response: LLMResponse,
|
||||||
|
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||||
|
"""处理函数工具调用。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||||
|
"""
|
||||||
|
# function calling
|
||||||
|
tool_call_result: list[ToolCallMessageSegment] = []
|
||||||
|
logger.info(
|
||||||
|
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||||
|
)
|
||||||
|
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||||
|
llm_response.tools_call_name,
|
||||||
|
llm_response.tools_call_args,
|
||||||
|
llm_response.tools_call_ids,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
func_tool = req.func_tool.get_func(func_tool_name)
|
||||||
|
if func_tool.origin == "mcp":
|
||||||
|
logger.info(
|
||||||
|
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||||
)
|
)
|
||||||
)
|
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||||
elif llm_response.role == "tool":
|
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||||
# function calling
|
if res:
|
||||||
function_calling_result = {}
|
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
||||||
logger.info(
|
if isinstance(res.content[0], TextContent):
|
||||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
tool_call_result.append(
|
||||||
)
|
ToolCallMessageSegment(
|
||||||
for func_tool_name, func_tool_args in zip(
|
role="tool",
|
||||||
llm_response.tools_call_name, llm_response.tools_call_args
|
tool_call_id=func_tool_id,
|
||||||
):
|
content=res.content[0].text,
|
||||||
func_tool = req.func_tool.get_func(func_tool_name)
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(res.content[0], ImageContent):
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content="返回了图片(已直接发送给用户)",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
event.set_extra(
|
||||||
|
"tool_call_img_respond",
|
||||||
|
res.content[0].data,
|
||||||
|
)
|
||||||
|
elif isinstance(res.content[0], EmbeddedResource):
|
||||||
|
resource = res.content[0].resource
|
||||||
|
if isinstance(resource, TextResourceContents):
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content=resource.text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
isinstance(resource, BlobResourceContents)
|
||||||
|
and resource.mimeType
|
||||||
|
and resource.mimeType.startswith("image/")
|
||||||
|
):
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content="返回了图片(已直接发送给用户)",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
event.set_extra(
|
||||||
|
"tool_call_img_respond",
|
||||||
|
res.content[0].data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_call_result.append(
|
||||||
|
ToolCallMessageSegment(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=func_tool_id,
|
||||||
|
content="返回的数据类型不受支持",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 获取处理器,过滤掉平台不兼容的处理器
|
||||||
|
platform_id = event.get_platform_id()
|
||||||
|
star_md = star_map.get(func_tool.handler_module_path)
|
||||||
|
if (
|
||||||
|
star_md
|
||||||
|
and platform_id in star_md.supported_platforms
|
||||||
|
and not star_md.supported_platforms[platform_id]
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||||
|
)
|
||||||
|
# 直接跳过,不添加任何消息到tool_call_result
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||||
)
|
)
|
||||||
try:
|
# 尝试调用工具函数
|
||||||
# 尝试调用工具函数
|
wrapper = self._call_handler(
|
||||||
wrapper = self._call_handler(
|
self.ctx, event, func_tool.handler, **func_tool_args
|
||||||
self.ctx, event, func_tool.handler, **func_tool_args
|
)
|
||||||
)
|
async for resp in wrapper:
|
||||||
async for resp in wrapper:
|
if resp is not None: # 有 return 返回
|
||||||
if resp is not None: # 有 return 返回
|
tool_call_result.append(
|
||||||
function_calling_result[func_tool_name] = resp
|
ToolCallMessageSegment(
|
||||||
else:
|
role="tool",
|
||||||
yield # 有生成器返回
|
tool_call_id=func_tool_id,
|
||||||
event.clear_result() # 清除上一个 handler 的结果
|
content=resp,
|
||||||
except BaseException as e:
|
)
|
||||||
logger.warning(traceback.format_exc())
|
)
|
||||||
function_calling_result[func_tool_name] = (
|
else:
|
||||||
"When calling the function, an error occurred: " + str(e)
|
res = event.get_result()
|
||||||
)
|
if res and res.chain:
|
||||||
if function_calling_result:
|
event.set_extra("tool_call_result", res)
|
||||||
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
|
yield # 有生成器返回
|
||||||
# 我们重新执行一遍这个 stage
|
event.clear_result() # 清除上一个 handler 的结果
|
||||||
req.func_tool = None # 暂时不支持递归工具调用
|
except BaseException as e:
|
||||||
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
|
logger.warning(traceback.format_exc())
|
||||||
for tool_name, tool_result in function_calling_result.items():
|
tool_call_result.append(
|
||||||
extra_prompt += (
|
ToolCallMessageSegment(
|
||||||
f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
role="tool",
|
||||||
)
|
tool_call_id=func_tool_id,
|
||||||
req.prompt += extra_prompt
|
content=f"error: {str(e)}",
|
||||||
async for _ in self.process(event, _nested=True):
|
)
|
||||||
yield
|
|
||||||
else:
|
|
||||||
if llm_response.completion_text:
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(llm_response.completion_text)
|
|
||||||
)
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
event.set_result(
|
|
||||||
MessageEventResult().message(
|
|
||||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
if tool_call_result:
|
||||||
|
# 函数调用结果
|
||||||
|
req.func_tool = None # 暂时不支持递归工具调用
|
||||||
|
assistant_msg_seg = AssistantMessageSegment(
|
||||||
|
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||||
)
|
)
|
||||||
return
|
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||||
|
req.tool_calls_result = ToolCallsResult(
|
||||||
|
tool_calls_info=assistant_msg_seg,
|
||||||
|
tool_calls_result=tool_call_result,
|
||||||
|
)
|
||||||
|
yield req # 再次执行 LLM 请求
|
||||||
|
else:
|
||||||
|
if llm_response.completion_text:
|
||||||
|
event.set_result(
|
||||||
|
MessageEventResult().message(llm_response.completion_text)
|
||||||
|
)
|
||||||
|
|
||||||
async def _save_to_history(
|
async def _save_to_history(
|
||||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||||
@@ -223,9 +526,23 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
if llm_response.role == "assistant":
|
if llm_response.role == "assistant":
|
||||||
# 文本回复
|
# 文本回复
|
||||||
contexts = req.contexts
|
contexts = req.contexts.copy()
|
||||||
new_record = {"role": "user", "content": req.prompt}
|
contexts.append(await req.assemble_context())
|
||||||
contexts.append(new_record)
|
|
||||||
|
# 记录并标记函数调用结果
|
||||||
|
if req.tool_calls_result:
|
||||||
|
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||||
|
|
||||||
|
# 添加标记
|
||||||
|
for message in tool_calls_messages:
|
||||||
|
message["_tool_call_history"] = True
|
||||||
|
|
||||||
|
processed_tool_messages = self._process_tool_message_pairs(
|
||||||
|
tool_calls_messages, remove_tags=False
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.extend(processed_tool_messages)
|
||||||
|
|
||||||
contexts.append(
|
contexts.append(
|
||||||
{"role": "assistant", "content": llm_response.completion_text}
|
{"role": "assistant", "content": llm_response.completion_text}
|
||||||
)
|
)
|
||||||
@@ -235,3 +552,59 @@ class LLMRequestSubStage(Stage):
|
|||||||
await self.conv_manager.update_conversation(
|
await self.conv_manager.update_conversation(
|
||||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||||
|
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list): 消息列表
|
||||||
|
remove_tags (bool): 是否移除_tool_call_history标记
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
while i < len(messages):
|
||||||
|
current_msg = messages[i]
|
||||||
|
|
||||||
|
# 普通消息直接添加
|
||||||
|
if "_tool_call_history" not in current_msg:
|
||||||
|
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 工具调用消息成对处理
|
||||||
|
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||||
|
assistant_msg = current_msg.copy()
|
||||||
|
|
||||||
|
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||||
|
del assistant_msg["_tool_call_history"]
|
||||||
|
|
||||||
|
related_tools = []
|
||||||
|
j = i + 1
|
||||||
|
while (
|
||||||
|
j < len(messages)
|
||||||
|
and messages[j].get("role") == "tool"
|
||||||
|
and "_tool_call_history" in messages[j]
|
||||||
|
):
|
||||||
|
tool_msg = messages[j].copy()
|
||||||
|
|
||||||
|
if remove_tags:
|
||||||
|
del tool_msg["_tool_call_history"]
|
||||||
|
|
||||||
|
related_tools.append(tool_msg)
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# 成对的时候添加到结果
|
||||||
|
if related_tools:
|
||||||
|
result.append(assistant_msg)
|
||||||
|
result.extend(related_tools)
|
||||||
|
|
||||||
|
i = j # 跳过已处理
|
||||||
|
else:
|
||||||
|
# 单独的tool消息
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
|
|||||||
)
|
)
|
||||||
if not handlers_parsed_params:
|
if not handlers_parsed_params:
|
||||||
handlers_parsed_params = {}
|
handlers_parsed_params = {}
|
||||||
|
|
||||||
for handler in activated_handlers:
|
for handler in activated_handlers:
|
||||||
|
# 检查处理器是否在当前平台兼容
|
||||||
|
if (
|
||||||
|
hasattr(handler, "platform_compatible")
|
||||||
|
and handler.platform_compatible is False
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||||
try:
|
try:
|
||||||
if handler.handler_module_path not in star_map:
|
if handler.handler_module_path not in star_map:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage
|
|||||||
from .method.star_request import StarRequestSubStage
|
from .method.star_request import StarRequestSubStage
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from astrbot.core.provider.entities import ProviderRequest
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -58,33 +58,30 @@ class RateLimitStage(Stage):
|
|||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
||||||
timestamps = self.event_timestamps[session_id]
|
# 检查并处理限流,可能需要多次检查直到满足条件
|
||||||
|
while True:
|
||||||
|
timestamps = self.event_timestamps[session_id]
|
||||||
|
self._remove_expired_timestamps(timestamps, now)
|
||||||
|
|
||||||
self._remove_expired_timestamps(timestamps, now)
|
if len(timestamps) < self.rate_limit_count:
|
||||||
|
timestamps.append(now)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
next_window_time = timestamps[0] + self.rate_limit_time
|
||||||
|
stall_duration = (next_window_time - now).total_seconds() + 0.3
|
||||||
|
|
||||||
if len(timestamps) >= self.rate_limit_count:
|
match self.rl_strategy:
|
||||||
# 达到限流阈值,计算下一个窗口的时间
|
case RateLimitStrategy.STALL.value:
|
||||||
next_window_time = timestamps[0] + self.rate_limit_time
|
logger.info(
|
||||||
stall_duration = (next_window_time - now).total_seconds()
|
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||||
|
)
|
||||||
match self.rl_strategy:
|
await asyncio.sleep(stall_duration)
|
||||||
case RateLimitStrategy.STALL.value:
|
now = datetime.now()
|
||||||
logger.info(
|
case RateLimitStrategy.DISCARD.value:
|
||||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
logger.info(
|
||||||
)
|
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||||
await asyncio.sleep(stall_duration)
|
)
|
||||||
case RateLimitStrategy.DISCARD.value:
|
return event.stop_event()
|
||||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
|
||||||
logger.info(
|
|
||||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
|
||||||
)
|
|
||||||
return event.stop_event()
|
|
||||||
|
|
||||||
self._remove_expired_timestamps(
|
|
||||||
timestamps, now + timedelta(seconds=stall_duration)
|
|
||||||
)
|
|
||||||
|
|
||||||
timestamps.append(now)
|
|
||||||
|
|
||||||
def _remove_expired_timestamps(
|
def _remove_expired_timestamps(
|
||||||
self, timestamps: Deque[datetime], now: datetime
|
self, timestamps: Deque[datetime], now: datetime
|
||||||
|
|||||||
@@ -2,22 +2,42 @@ import random
|
|||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import traceback
|
import traceback
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage, Stage
|
from ..stage import register_stage, Stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.message.components import Plain, Reply, At
|
from astrbot.core.utils.path_util import path_Mapping
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
class RespondStage(Stage):
|
class RespondStage(Stage):
|
||||||
|
# 组件类型到其非空判断函数的映射
|
||||||
|
_component_validators = {
|
||||||
|
Comp.Plain: lambda comp: bool(
|
||||||
|
comp.text and comp.text.strip()
|
||||||
|
), # 纯文本消息需要strip
|
||||||
|
Comp.Face: lambda comp: comp.id is not None, # QQ表情
|
||||||
|
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||||
|
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||||
|
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||||
|
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||||
|
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||||
|
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||||
|
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||||
|
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||||
|
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||||
|
}
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext):
|
async def initialize(self, ctx: PipelineContext):
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
|
self.config = ctx.astrbot_config
|
||||||
|
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||||
|
|
||||||
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||||
"reply_with_mention"
|
"reply_with_mention"
|
||||||
@@ -62,7 +82,7 @@ class RespondStage(Stage):
|
|||||||
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
||||||
"""分段回复 计算间隔时间"""
|
"""分段回复 计算间隔时间"""
|
||||||
if self.interval_method == "log":
|
if self.interval_method == "log":
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Comp.Plain):
|
||||||
wc = await self._word_cnt(comp.text)
|
wc = await self._word_cnt(comp.text)
|
||||||
i = math.log(wc + 1, self.log_base)
|
i = math.log(wc + 1, self.log_base)
|
||||||
return random.uniform(i, i + 0.5)
|
return random.uniform(i, i + 0.5)
|
||||||
@@ -72,15 +92,70 @@ class RespondStage(Stage):
|
|||||||
# random
|
# random
|
||||||
return random.uniform(self.interval[0], self.interval[1])
|
return random.uniform(self.interval[0], self.interval[1])
|
||||||
|
|
||||||
|
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
|
||||||
|
"""检查消息链是否为空
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chain (list[BaseMessageComponent]): 包含消息对象的列表
|
||||||
|
"""
|
||||||
|
if not chain:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for comp in chain:
|
||||||
|
comp_type = type(comp)
|
||||||
|
|
||||||
|
# 检查组件类型是否在字典中
|
||||||
|
if comp_type in self._component_validators:
|
||||||
|
if self._component_validators[comp_type](comp):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 如果所有组件都为空
|
||||||
|
return True
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
|
return
|
||||||
|
|
||||||
if len(result.chain) > 0:
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
# 流式结果直接交付平台适配器处理
|
||||||
|
use_fallback = self.config.get("provider_settings", {}).get(
|
||||||
|
"streaming_segmented", False
|
||||||
|
)
|
||||||
|
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||||
await event._pre_send()
|
await event._pre_send()
|
||||||
|
await event.send_streaming(result.async_stream, use_fallback)
|
||||||
|
await event._post_send()
|
||||||
|
return
|
||||||
|
elif len(result.chain) > 0:
|
||||||
|
# 检查路径映射
|
||||||
|
if mappings := self.platform_settings.get("path_mapping", []):
|
||||||
|
for idx, component in enumerate(result.chain):
|
||||||
|
if isinstance(component, Comp.File) and component.file:
|
||||||
|
# 支持 File 消息段的路径映射。
|
||||||
|
component.file = path_Mapping(mappings, component.file)
|
||||||
|
event.get_result().chain[idx] = component
|
||||||
|
|
||||||
|
await event._pre_send()
|
||||||
|
|
||||||
|
# 检查消息链是否为空
|
||||||
|
try:
|
||||||
|
if await self._is_empty_message_chain(result.chain):
|
||||||
|
logger.info("消息为空,跳过发送阶段")
|
||||||
|
event.clear_result()
|
||||||
|
event.stop_event()
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"空内容检查异常: {e}")
|
||||||
|
|
||||||
|
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||||
|
non_record_comps = [
|
||||||
|
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||||
|
]
|
||||||
|
|
||||||
if self.enable_seg and (
|
if self.enable_seg and (
|
||||||
(self.only_llm_result and result.is_llm_result())
|
(self.only_llm_result and result.is_llm_result())
|
||||||
@@ -89,30 +164,55 @@ class RespondStage(Stage):
|
|||||||
decorated_comps = []
|
decorated_comps = []
|
||||||
if self.reply_with_mention:
|
if self.reply_with_mention:
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, At):
|
if isinstance(comp, Comp.At):
|
||||||
decorated_comps.append(comp)
|
decorated_comps.append(comp)
|
||||||
result.chain.remove(comp)
|
result.chain.remove(comp)
|
||||||
break
|
break
|
||||||
if self.reply_with_quote:
|
if self.reply_with_quote:
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Reply):
|
if isinstance(comp, Comp.Reply):
|
||||||
decorated_comps.append(comp)
|
decorated_comps.append(comp)
|
||||||
result.chain.remove(comp)
|
result.chain.remove(comp)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
for rcomp in record_comps:
|
||||||
|
i = await self._calc_comp_interval(rcomp)
|
||||||
|
await asyncio.sleep(i)
|
||||||
|
try:
|
||||||
|
await event.send(MessageChain([rcomp]))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
break
|
||||||
|
|
||||||
# 分段回复
|
# 分段回复
|
||||||
for comp in result.chain:
|
for comp in non_record_comps:
|
||||||
i = await self._calc_comp_interval(comp)
|
i = await self._calc_comp_interval(comp)
|
||||||
await asyncio.sleep(i)
|
await asyncio.sleep(i)
|
||||||
await event.send(MessageChain([*decorated_comps, comp]))
|
try:
|
||||||
|
await event.send(MessageChain([*decorated_comps, comp]))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
await event.send(result)
|
for rcomp in record_comps:
|
||||||
|
try:
|
||||||
|
await event.send(MessageChain([rcomp]))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await event.send(MessageChain(non_record_comps))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
|
||||||
await event._post_send()
|
await event._post_send()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
EventType.OnAfterMessageSentEvent
|
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
import time
|
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Union, AsyncGenerator
|
from typing import AsyncGenerator, Union
|
||||||
from ..stage import Stage, register_stage, registered_stages
|
|
||||||
from ..context import PipelineContext
|
from astrbot.core import html_renderer, logger, file_token_service
|
||||||
|
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||||
|
from astrbot.core.message.message_event_result import ResultContentType
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from astrbot.core import logger
|
|
||||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
|
||||||
from astrbot.core import html_renderer
|
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||||
|
|
||||||
|
from ..context import PipelineContext
|
||||||
|
from ..stage import Stage, register_stage, registered_stages
|
||||||
|
|
||||||
|
|
||||||
@register_stage
|
@register_stage
|
||||||
@@ -31,6 +33,8 @@ class ResultDecorateStage(Stage):
|
|||||||
self.t2i_word_threshold = 50
|
self.t2i_word_threshold = 50
|
||||||
except BaseException:
|
except BaseException:
|
||||||
self.t2i_word_threshold = 150
|
self.t2i_word_threshold = 150
|
||||||
|
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
||||||
|
self.t2i_use_network = self.t2i_strategy == "remote"
|
||||||
|
|
||||||
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
||||||
"forward_threshold"
|
"forward_threshold"
|
||||||
@@ -70,11 +74,17 @@ class ResultDecorateStage(Stage):
|
|||||||
if result is None or not result.chain:
|
if result is None or not result.chain:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
|
||||||
|
|
||||||
# 回复时检查内容安全
|
# 回复时检查内容安全
|
||||||
if (
|
if (
|
||||||
self.content_safe_check_reply
|
self.content_safe_check_reply
|
||||||
and self.content_safe_check_stage
|
and self.content_safe_check_stage
|
||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
|
and not is_stream # 流式输出不检查内容安全
|
||||||
):
|
):
|
||||||
text = ""
|
text = ""
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
@@ -87,13 +97,17 @@ class ResultDecorateStage(Stage):
|
|||||||
|
|
||||||
# 发送消息前事件钩子
|
# 发送消息前事件钩子
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
EventType.OnDecoratingResultEvent
|
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
)
|
)
|
||||||
|
if is_stream:
|
||||||
|
logger.warning(
|
||||||
|
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
|
||||||
|
)
|
||||||
await handler.handler(event)
|
await handler.handler(event)
|
||||||
if event.get_result() is None or not event.get_result().chain:
|
if event.get_result() is None or not event.get_result().chain:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -108,6 +122,11 @@ class ResultDecorateStage(Stage):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 流式输出不执行下面的逻辑
|
||||||
|
if is_stream:
|
||||||
|
logger.info("流式输出已启用,跳过结果装饰阶段")
|
||||||
|
return
|
||||||
|
|
||||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
@@ -133,9 +152,9 @@ class ResultDecorateStage(Stage):
|
|||||||
# 不分段回复
|
# 不分段回复
|
||||||
new_chain.append(comp)
|
new_chain.append(comp)
|
||||||
continue
|
continue
|
||||||
split_response = []
|
split_response = re.findall(
|
||||||
for line in comp.text.split("\n"):
|
self.regex, comp.text, re.DOTALL | re.MULTILINE
|
||||||
split_response.extend(re.findall(self.regex, line))
|
)
|
||||||
if not split_response:
|
if not split_response:
|
||||||
new_chain.append(comp)
|
new_chain.append(comp)
|
||||||
continue
|
continue
|
||||||
@@ -150,28 +169,55 @@ class ResultDecorateStage(Stage):
|
|||||||
result.chain = new_chain
|
result.chain = new_chain
|
||||||
|
|
||||||
# TTS
|
# TTS
|
||||||
|
tts_provider = (
|
||||||
|
self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
|
and tts_provider
|
||||||
):
|
):
|
||||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
|
||||||
new_chain = []
|
new_chain = []
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||||
try:
|
try:
|
||||||
logger.info("TTS 请求: " + comp.text)
|
logger.info(f"TTS 请求: {comp.text}")
|
||||||
audio_path = await tts_provider.get_audio(comp.text)
|
audio_path = await tts_provider.get_audio(comp.text)
|
||||||
logger.info("TTS 结果: " + audio_path)
|
logger.info(f"TTS 结果: {audio_path}")
|
||||||
if audio_path:
|
if not audio_path:
|
||||||
new_chain.append(
|
|
||||||
Record(file=audio_path, url=audio_path)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||||
)
|
)
|
||||||
new_chain.append(comp)
|
new_chain.append(comp)
|
||||||
except BaseException:
|
continue
|
||||||
|
|
||||||
|
use_file_service = self.ctx.astrbot_config[
|
||||||
|
"provider_tts_settings"
|
||||||
|
]["use_file_service"]
|
||||||
|
callback_api_base = self.ctx.astrbot_config[
|
||||||
|
"callback_api_base"
|
||||||
|
]
|
||||||
|
dual_output = self.ctx.astrbot_config[
|
||||||
|
"provider_tts_settings"
|
||||||
|
]["dual_output"]
|
||||||
|
|
||||||
|
url = None
|
||||||
|
if use_file_service and callback_api_base:
|
||||||
|
token = await file_token_service.register_file(
|
||||||
|
audio_path
|
||||||
|
)
|
||||||
|
url = f"{callback_api_base}/api/file/{token}"
|
||||||
|
logger.debug(f"已注册:{url}")
|
||||||
|
|
||||||
|
new_chain.append(
|
||||||
|
Record(
|
||||||
|
file=url or audio_path,
|
||||||
|
url=url or audio_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if dual_output:
|
||||||
|
new_chain.append(comp)
|
||||||
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error("TTS 失败,使用文本发送。")
|
logger.error("TTS 失败,使用文本发送。")
|
||||||
new_chain.append(comp)
|
new_chain.append(comp)
|
||||||
@@ -192,7 +238,9 @@ class ResultDecorateStage(Stage):
|
|||||||
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
if plain_str and len(plain_str) > self.t2i_word_threshold:
|
||||||
render_start = time.time()
|
render_start = time.time()
|
||||||
try:
|
try:
|
||||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
url = await html_renderer.render_t2i(
|
||||||
|
plain_str, return_url=True, use_network=self.t2i_use_network
|
||||||
|
)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error("文本转图片失败,使用文本发送。")
|
logger.error("文本转图片失败,使用文本发送。")
|
||||||
return
|
return
|
||||||
@@ -201,7 +249,18 @@ class ResultDecorateStage(Stage):
|
|||||||
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
|
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
|
||||||
)
|
)
|
||||||
if url:
|
if url:
|
||||||
result.chain = [Image.fromURL(url)]
|
if url.startswith("http"):
|
||||||
|
result.chain = [Image.fromURL(url)]
|
||||||
|
elif (
|
||||||
|
self.ctx.astrbot_config["t2i_use_file_service"]
|
||||||
|
and self.ctx.astrbot_config["callback_api_base"]
|
||||||
|
):
|
||||||
|
token = await file_token_service.register_file(url)
|
||||||
|
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
|
||||||
|
logger.debug(f"已注册:{url}")
|
||||||
|
result.chain = [Image.fromURL(url)]
|
||||||
|
else:
|
||||||
|
result.chain = [Image.fromFileSystem(url)]
|
||||||
|
|
||||||
# 触发转发消息
|
# 触发转发消息
|
||||||
has_forwarded = False
|
has_forwarded = False
|
||||||
|
|||||||
@@ -7,49 +7,72 @@ from astrbot.core import logger
|
|||||||
|
|
||||||
|
|
||||||
class PipelineScheduler:
|
class PipelineScheduler:
|
||||||
|
"""管道调度器,负责调度各个阶段的执行"""
|
||||||
|
|
||||||
def __init__(self, context: PipelineContext):
|
def __init__(self, context: PipelineContext):
|
||||||
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__))
|
registered_stages.sort(
|
||||||
self.ctx = context
|
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||||
|
) # 按照顺序排序
|
||||||
|
self.ctx = context # 上下文对象
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
"""初始化管道调度器时, 初始化所有阶段"""
|
||||||
for stage in registered_stages:
|
for stage in registered_stages:
|
||||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||||
|
|
||||||
await stage.initialize(self.ctx)
|
await stage.initialize(self.ctx)
|
||||||
|
|
||||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||||
|
"""依次执行各个阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||||
|
"""
|
||||||
for i in range(from_stage, len(registered_stages)):
|
for i in range(from_stage, len(registered_stages)):
|
||||||
stage = registered_stages[i]
|
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||||
coro = stage.process(event)
|
coroutine = stage.process(
|
||||||
if isinstance(coro, AsyncGenerator):
|
event
|
||||||
async for _ in coro:
|
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||||
|
|
||||||
|
if isinstance(coroutine, AsyncGenerator):
|
||||||
|
# 如果返回的是异步生成器, 实现洋葱模型的核心
|
||||||
|
async for _ in coroutine:
|
||||||
|
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 递归调用, 处理所有后续阶段
|
||||||
await self._process_stages(event, i + 1)
|
await self._process_stages(event, i + 1)
|
||||||
|
|
||||||
|
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
await coro
|
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
|
||||||
|
# 简单地等待它执行完成, 然后继续执行下一个阶段
|
||||||
|
await coroutine
|
||||||
|
|
||||||
if event.is_stopped():
|
if event.is_stopped():
|
||||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
||||||
break
|
break
|
||||||
|
|
||||||
if event.is_stopped():
|
|
||||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
|
||||||
break
|
|
||||||
|
|
||||||
async def execute(self, event: AstrMessageEvent):
|
async def execute(self, event: AstrMessageEvent):
|
||||||
"""执行 pipeline"""
|
"""执行 pipeline
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象
|
||||||
|
"""
|
||||||
await self._process_stages(event)
|
await self._process_stages(event)
|
||||||
|
|
||||||
|
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||||
await event.send(None)
|
await event.send(None)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import inspect
|
import inspect
|
||||||
|
import traceback
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from typing import List, AsyncGenerator, Union, Awaitable
|
from typing import List, AsyncGenerator, Union, Awaitable
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from .context import PipelineContext
|
from .context import PipelineContext
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||||
|
|
||||||
registered_stages: List[Stage] = []
|
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||||
"""维护了所有已注册的 Stage 实现类"""
|
|
||||||
|
|
||||||
|
|
||||||
def register_stage(cls):
|
def register_stage(cls):
|
||||||
@@ -22,14 +22,24 @@ class Stage(abc.ABC):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
"""初始化阶段"""
|
"""初始化阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""处理事件"""
|
"""处理事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (AstrMessageEvent): 事件对象,包含事件的相关信息
|
||||||
|
Returns:
|
||||||
|
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _call_handler(
|
async def _call_handler(
|
||||||
@@ -40,33 +50,61 @@ class Stage(abc.ABC):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[None, None]:
|
) -> AsyncGenerator[None, None]:
|
||||||
"""调用 Handler。"""
|
"""执行事件处理函数并处理其返回结果
|
||||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
|
||||||
ready_to_call = None
|
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||||
|
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||||
|
2. 协程: 执行一次并处理返回值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象
|
||||||
|
event (AstrMessageEvent): 待处理的事件对象
|
||||||
|
handler (Awaitable): 事件处理函数
|
||||||
|
*args: 传递给handler的位置参数
|
||||||
|
**kwargs: 传递给handler的关键字参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||||
|
"""
|
||||||
|
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||||
|
|
||||||
|
trace_ = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ready_to_call = handler(event, *args, **kwargs)
|
ready_to_call = handler(event, *args, **kwargs)
|
||||||
except TypeError as e:
|
except TypeError as _:
|
||||||
# 向下兼容
|
# 向下兼容
|
||||||
logger.debug(str(e))
|
trace_ = traceback.format_exc()
|
||||||
|
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(ready_to_call, AsyncGenerator):
|
if isinstance(ready_to_call, AsyncGenerator):
|
||||||
_has_yielded = False
|
# 如果是一个异步生成器, 进入洋葱模型
|
||||||
async for ret in ready_to_call:
|
_has_yielded = False # 是否返回过值
|
||||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
try:
|
||||||
_has_yielded = True
|
async for ret in ready_to_call:
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||||
event.set_result(ret)
|
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||||
|
_has_yielded = True
|
||||||
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
|
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||||
|
event.set_result(ret)
|
||||||
|
yield # 传递控制权给上一层的process函数
|
||||||
|
else:
|
||||||
|
# 如果返回值是 None, 则不设置结果并继续
|
||||||
|
# 继续执行后续阶段
|
||||||
|
yield ret # 传递控制权给上一层的process函数
|
||||||
|
if not _has_yielded:
|
||||||
|
# 如果这个异步生成器没有执行到yield分支
|
||||||
yield
|
yield
|
||||||
else:
|
except Exception as e:
|
||||||
yield ret
|
logger.error(f"Previous Error: {trace_}")
|
||||||
if not _has_yielded:
|
raise e
|
||||||
yield
|
|
||||||
elif inspect.iscoroutine(ready_to_call):
|
elif inspect.iscoroutine(ready_to_call):
|
||||||
# 如果只是一个 coroutine
|
# 如果只是一个协程, 直接执行
|
||||||
ret = await ready_to_call
|
ret = await ready_to_call
|
||||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||||
event.set_result(ret)
|
event.set_result(ret)
|
||||||
yield
|
yield # 传递控制权给上一层的process函数
|
||||||
else:
|
else:
|
||||||
yield ret
|
yield ret # 传递控制权给上一层的process函数
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from ..stage import Stage, register_stage
|
from ..stage import Stage, register_stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
|
from astrbot import logger
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||||
@@ -21,18 +22,38 @@ class WakingCheckStage(Stage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
|
"""初始化唤醒检查阶段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||||
|
"""
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
||||||
"no_permission_reply", True
|
"no_permission_reply", True
|
||||||
)
|
)
|
||||||
|
# 私聊是否需要 wake_prefix 才能唤醒机器人
|
||||||
|
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
|
||||||
|
"platform_settings"
|
||||||
|
].get("friend_message_needs_wake_prefix", False)
|
||||||
|
# 是否忽略机器人自己发送的消息
|
||||||
|
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||||
|
"ignore_bot_self_message", False
|
||||||
|
)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
|
if (
|
||||||
|
self.ignore_bot_self_message
|
||||||
|
and event.get_self_id() == event.get_sender_id()
|
||||||
|
):
|
||||||
|
# 忽略机器人自己发送的消息
|
||||||
|
event.stop_event()
|
||||||
|
return
|
||||||
# 设置 sender 身份
|
# 设置 sender 身份
|
||||||
event.message_str = event.message_str.strip()
|
event.message_str = event.message_str.strip()
|
||||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||||
if event.get_sender_id() == admin_id:
|
if str(event.get_sender_id()) == admin_id:
|
||||||
event.role = "admin"
|
event.role = "admin"
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -68,7 +89,7 @@ class WakingCheckStage(Stage):
|
|||||||
event.is_at_or_wake_command = True
|
event.is_at_or_wake_command = True
|
||||||
break
|
break
|
||||||
# 检查是否是私聊
|
# 检查是否是私聊
|
||||||
if event.is_private_chat():
|
if event.is_private_chat() and not self.friend_message_needs_wake_prefix:
|
||||||
is_wake = True
|
is_wake = True
|
||||||
event.is_wake = True
|
event.is_wake = True
|
||||||
event.is_at_or_wake_command = True
|
event.is_at_or_wake_command = True
|
||||||
@@ -84,6 +105,7 @@ class WakingCheckStage(Stage):
|
|||||||
# filter 需满足 AND 逻辑关系
|
# filter 需满足 AND 逻辑关系
|
||||||
passed = True
|
passed = True
|
||||||
permission_not_pass = False
|
permission_not_pass = False
|
||||||
|
permission_filter_raise_error = False
|
||||||
if len(handler.event_filters) == 0:
|
if len(handler.event_filters) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -92,6 +114,7 @@ class WakingCheckStage(Stage):
|
|||||||
if isinstance(filter, PermissionTypeFilter):
|
if isinstance(filter, PermissionTypeFilter):
|
||||||
if not filter.filter(event, self.ctx.astrbot_config):
|
if not filter.filter(event, self.ctx.astrbot_config):
|
||||||
permission_not_pass = True
|
permission_not_pass = True
|
||||||
|
permission_filter_raise_error = filter.raise_error
|
||||||
else:
|
else:
|
||||||
if not filter.filter(event, self.ctx.astrbot_config):
|
if not filter.filter(event, self.ctx.astrbot_config):
|
||||||
passed = False
|
passed = False
|
||||||
@@ -102,17 +125,25 @@ class WakingCheckStage(Stage):
|
|||||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
await event._post_send()
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
passed = False
|
passed = False
|
||||||
break
|
break
|
||||||
if passed:
|
if passed:
|
||||||
if permission_not_pass:
|
if permission_not_pass:
|
||||||
|
if not permission_filter_raise_error:
|
||||||
|
# 跳过
|
||||||
|
continue
|
||||||
if self.no_permission_reply:
|
if self.no_permission_reply:
|
||||||
await event.send(
|
await event.send(
|
||||||
MessageChain().message(
|
MessageChain().message(
|
||||||
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
await event._post_send()
|
||||||
|
logger.info(
|
||||||
|
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||||
|
)
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ class WhitelistCheckStage(Stage):
|
|||||||
"enable_id_white_list"
|
"enable_id_white_list"
|
||||||
]
|
]
|
||||||
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
|
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
|
||||||
|
self.whitelist = [
|
||||||
|
str(i).strip() for i in self.whitelist if str(i).strip() != ""
|
||||||
|
]
|
||||||
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
|
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
|
||||||
"wl_ignore_admin_on_group"
|
"wl_ignore_admin_on_group"
|
||||||
]
|
]
|
||||||
@@ -51,7 +54,10 @@ class WhitelistCheckStage(Stage):
|
|||||||
and event.get_message_type() == MessageType.FRIEND_MESSAGE
|
and event.get_message_type() == MessageType.FRIEND_MESSAGE
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
if event.unified_msg_origin not in self.whitelist:
|
if (
|
||||||
|
event.unified_msg_origin not in self.whitelist
|
||||||
|
and str(event.get_group_id()).strip() not in self.whitelist
|
||||||
|
):
|
||||||
if self.wl_log:
|
if self.wl_log:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
|
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .platform import Platform
|
from .platform import Platform
|
||||||
from .astr_message_event import AstrMessageEvent
|
from .astr_message_event import AstrMessageEvent
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
from .astrbot_message import AstrBotMessage, MessageMember, MessageType
|
from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Platform",
|
"Platform",
|
||||||
@@ -10,4 +10,5 @@ __all__ = [
|
|||||||
"AstrBotMessage",
|
"AstrBotMessage",
|
||||||
"MessageMember",
|
"MessageMember",
|
||||||
"MessageType",
|
"MessageType",
|
||||||
|
"Group",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from .astrbot_message import AstrBotMessage
|
from typing import List, Union, Optional, AsyncGenerator
|
||||||
from .platform_metadata import PlatformMetadata
|
|
||||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
from astrbot.core.db.po import Conversation
|
||||||
from astrbot.core.platform.message_type import MessageType
|
|
||||||
from typing import List, Union
|
|
||||||
from astrbot.core.message.components import (
|
from astrbot.core.message.components import (
|
||||||
Plain,
|
Plain,
|
||||||
Image,
|
Image,
|
||||||
@@ -14,10 +15,14 @@ from astrbot.core.message.components import (
|
|||||||
At,
|
At,
|
||||||
AtAll,
|
AtAll,
|
||||||
Forward,
|
Forward,
|
||||||
|
Reply,
|
||||||
)
|
)
|
||||||
|
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||||
|
from astrbot.core.platform.message_type import MessageType
|
||||||
|
from astrbot.core.provider.entities import ProviderRequest
|
||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from astrbot.core.provider.entites import ProviderRequest
|
from .astrbot_message import AstrBotMessage, Group
|
||||||
from astrbot.core.db.po import Conversation
|
from .platform_metadata import PlatformMetadata
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -79,6 +84,9 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
def get_platform_name(self):
|
def get_platform_name(self):
|
||||||
return self.platform_meta.name
|
return self.platform_meta.name
|
||||||
|
|
||||||
|
def get_platform_id(self):
|
||||||
|
return self.platform_meta.id
|
||||||
|
|
||||||
def get_message_str(self) -> str:
|
def get_message_str(self) -> str:
|
||||||
"""
|
"""
|
||||||
获取消息字符串。
|
获取消息字符串。
|
||||||
@@ -101,8 +109,15 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
elif isinstance(i, Forward):
|
elif isinstance(i, Forward):
|
||||||
# 转发消息
|
# 转发消息
|
||||||
outline += "[转发消息]"
|
outline += "[转发消息]"
|
||||||
|
elif isinstance(i, Reply):
|
||||||
|
# 引用回复
|
||||||
|
if i.message_str:
|
||||||
|
outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
|
||||||
|
else:
|
||||||
|
outline += "[引用消息]"
|
||||||
else:
|
else:
|
||||||
outline += f"[{i.type}]"
|
outline += f"[{i.type}]"
|
||||||
|
outline += " "
|
||||||
return outline
|
return outline
|
||||||
|
|
||||||
def get_message_outline(self) -> str:
|
def get_message_outline(self) -> str:
|
||||||
@@ -193,9 +208,26 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
return self.role == "admin"
|
return self.role == "admin"
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
|
||||||
"""
|
"""
|
||||||
发送消息到消息平台。
|
将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
match = re.search(pattern, buffer)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
matched_text = match.group()
|
||||||
|
await self.send(MessageChain([Plain(matched_text)]))
|
||||||
|
buffer = buffer[match.end() :]
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||||
|
):
|
||||||
|
"""发送流式消息到消息平台,使用异步生成器。
|
||||||
|
目前仅支持: telegram,qq official 私聊。
|
||||||
|
Fallback仅支持 aiocqhttp, gewechat。
|
||||||
"""
|
"""
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||||
@@ -363,3 +395,31 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""平台适配器"""
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
"""发送消息到消息平台。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (MessageChain): 消息链,具体使用方式请参考文档。
|
||||||
|
"""
|
||||||
|
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
|
||||||
|
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
|
||||||
|
sid = str(uuid.UUID(bytes=hash_obj.digest()))
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._has_send_oper = True
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||||
|
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||||
|
|
||||||
|
适配情况:
|
||||||
|
|
||||||
|
- gewechat
|
||||||
|
- aiocqhttp(OneBotv11)
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -10,6 +10,41 @@ class MessageMember:
|
|||||||
user_id: str # 发送者id
|
user_id: str # 发送者id
|
||||||
nickname: str = None
|
nickname: str = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# 使用 f-string 来构建返回的字符串表示形式
|
||||||
|
return (
|
||||||
|
f"User ID: {self.user_id},"
|
||||||
|
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Group:
|
||||||
|
group_id: str
|
||||||
|
"""群号"""
|
||||||
|
group_name: str = None
|
||||||
|
"""群名称"""
|
||||||
|
group_avatar: str = None
|
||||||
|
"""群头像"""
|
||||||
|
group_owner: str = None
|
||||||
|
"""群主 id"""
|
||||||
|
group_admins: List[str] = None
|
||||||
|
"""群管理员 id"""
|
||||||
|
members: List[MessageMember] = None
|
||||||
|
"""所有群成员"""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# 使用 f-string 来构建返回的字符串表示形式
|
||||||
|
return (
|
||||||
|
f"Group ID: {self.group_id}\n"
|
||||||
|
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
|
||||||
|
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
|
||||||
|
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
|
||||||
|
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
|
||||||
|
f"Members Len: {len(self.members) if self.members else 0}\n"
|
||||||
|
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AstrBotMessage:
|
class AstrBotMessage:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -62,12 +62,22 @@ class PlatformManager:
|
|||||||
from .sources.gewechat.gewechat_platform_adapter import (
|
from .sources.gewechat.gewechat_platform_adapter import (
|
||||||
GewechatPlatformAdapter, # noqa: F401
|
GewechatPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
|
case "wechatpadpro":
|
||||||
|
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||||
|
WeChatPadProAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "lark":
|
case "lark":
|
||||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||||
|
case "dingtalk":
|
||||||
|
from .sources.dingtalk.dingtalk_adapter import (
|
||||||
|
DingtalkPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "telegram":
|
case "telegram":
|
||||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||||
case "wecom":
|
case "wecom":
|
||||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||||
|
case "weixin_official_account":
|
||||||
|
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||||
@@ -81,14 +91,18 @@ class PlatformManager:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
cls_type = platform_cls_map[platform_config["type"]]
|
cls_type = platform_cls_map[platform_config["type"]]
|
||||||
inst = cls_type(platform_config, self.settings, self.event_queue)
|
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
|
||||||
self._inst_map[platform_config["id"]] = inst
|
self._inst_map[platform_config["id"]] = {
|
||||||
|
"inst": inst,
|
||||||
|
"client_id": inst.client_self_id,
|
||||||
|
}
|
||||||
self.platform_insts.append(inst)
|
self.platform_insts.append(inst)
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self._task_wrapper(
|
self._task_wrapper(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
inst.run(), name=platform_config["id"] + "_platform"
|
inst.run(),
|
||||||
|
name=f"platform_{platform_config['type']}_{platform_config['id']}",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -105,38 +119,42 @@ class PlatformManager:
|
|||||||
logger.error("-------")
|
logger.error("-------")
|
||||||
|
|
||||||
async def reload(self, platform_config: dict):
|
async def reload(self, platform_config: dict):
|
||||||
# 还未实现完成,不要调用此方法
|
await self.terminate_platform(platform_config["id"])
|
||||||
|
if platform_config["enable"]:
|
||||||
if platform_config["id"] in self._inst_map:
|
|
||||||
# 正在运行
|
|
||||||
if getattr(self._inst_map[platform_config["id"]], "terminate", None):
|
|
||||||
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
|
|
||||||
await self._inst_map[platform_config["id"]].terminate()
|
|
||||||
logger.info(f"{platform_config['id']} 平台适配器已终止。")
|
|
||||||
del self._inst_map[platform_config["id"]]
|
|
||||||
self.platform_insts.remove(self._inst_map[platform_config["id"]])
|
|
||||||
else:
|
|
||||||
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
|
|
||||||
|
|
||||||
# 再启动新的实例
|
|
||||||
await self.load_platform(platform_config)
|
await self.load_platform(platform_config)
|
||||||
|
|
||||||
else:
|
# 和配置文件保持同步
|
||||||
# 先将 _inst_map 中在 platform_config 中不存在的实例删除
|
config_ids = [provider["id"] for provider in self.platforms_config]
|
||||||
config_ids = [platform["id"] for platform in self.platforms_config]
|
for key in list(self._inst_map.keys()):
|
||||||
for key in list(self._inst_map.keys()):
|
if key not in config_ids:
|
||||||
if key not in config_ids:
|
await self.terminate_platform(key)
|
||||||
if getattr(self._inst_map[key], "terminate", None):
|
|
||||||
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
|
|
||||||
await self._inst_map[key].terminate()
|
|
||||||
logger.info(f"{key} 平台适配器已终止。")
|
|
||||||
del self._inst_map[key]
|
|
||||||
self.platform_insts.remove(self._inst_map[key])
|
|
||||||
else:
|
|
||||||
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
|
|
||||||
|
|
||||||
# 再启动新的实例
|
async def terminate_platform(self, platform_id: str):
|
||||||
await self.load_platform(platform_config)
|
if platform_id in self._inst_map:
|
||||||
|
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||||
|
|
||||||
|
# client_id = self._inst_map.pop(platform_id, None)
|
||||||
|
info = self._inst_map.pop(platform_id, None)
|
||||||
|
client_id = info["client_id"]
|
||||||
|
inst = info["inst"]
|
||||||
|
try:
|
||||||
|
self.platform_insts.remove(
|
||||||
|
next(
|
||||||
|
inst
|
||||||
|
for inst in self.platform_insts
|
||||||
|
if inst.client_self_id == client_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
|
||||||
|
|
||||||
|
if getattr(inst, "terminate", None):
|
||||||
|
await inst.terminate()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
for inst in self.platform_insts:
|
||||||
|
if getattr(inst, "terminate", None):
|
||||||
|
await inst.terminate()
|
||||||
|
|
||||||
def get_insts(self):
|
def get_insts(self):
|
||||||
return self.platform_insts
|
return self.platform_insts
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import uuid
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
@@ -13,6 +14,7 @@ class Platform(abc.ABC):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||||
self._event_queue = event_queue
|
self._event_queue = event_queue
|
||||||
|
self.client_self_id = uuid.uuid4().hex
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self) -> Awaitable[Any]:
|
def run(self) -> Awaitable[Any]:
|
||||||
@@ -25,7 +27,7 @@ class Platform(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
终止一个平台的运行实例。
|
终止一个平台的运行实例。
|
||||||
"""
|
"""
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ class PlatformMetadata:
|
|||||||
"""平台的名称"""
|
"""平台的名称"""
|
||||||
description: str
|
description: str
|
||||||
"""平台的描述"""
|
"""平台的描述"""
|
||||||
|
id: str = None
|
||||||
|
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||||
|
|
||||||
default_config_tmpl: dict = None
|
default_config_tmpl: dict = None
|
||||||
"""平台的默认配置模板"""
|
"""平台的默认配置模板"""
|
||||||
|
|||||||
@@ -1,9 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from typing import AsyncGenerator, Dict, List
|
||||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
|
||||||
from aiocqhttp import CQHttp
|
from aiocqhttp import CQHttp
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.message_components import (
|
||||||
|
Image,
|
||||||
|
Node,
|
||||||
|
Nodes,
|
||||||
|
Plain,
|
||||||
|
Record,
|
||||||
|
Video,
|
||||||
|
File,
|
||||||
|
BaseMessageComponent,
|
||||||
|
)
|
||||||
|
from astrbot.api.platform import Group, MessageMember
|
||||||
|
|
||||||
|
|
||||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||||
@@ -13,51 +23,57 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
|
||||||
|
"""修复部分字段"""
|
||||||
|
if isinstance(segment, (Image, Record)):
|
||||||
|
# For Image and Record segments, we convert them to base64
|
||||||
|
bs64 = await segment.convert_to_base64()
|
||||||
|
return {
|
||||||
|
"type": segment.type.lower(),
|
||||||
|
"data": {
|
||||||
|
"file": f"base64://{bs64}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
elif isinstance(segment, File):
|
||||||
|
# For File segments, we need to handle the file differently
|
||||||
|
d = await segment.to_dict()
|
||||||
|
return d
|
||||||
|
elif isinstance(segment, Video):
|
||||||
|
d = await segment.to_dict()
|
||||||
|
return d
|
||||||
|
else:
|
||||||
|
# For other segments, we simply convert them to a dict by calling toDict
|
||||||
|
return segment.toDict()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _parse_onebot_json(message_chain: MessageChain):
|
async def _parse_onebot_json(message_chain: MessageChain):
|
||||||
"""解析成 OneBot json 格式"""
|
"""解析成 OneBot json 格式"""
|
||||||
ret = []
|
ret = []
|
||||||
for segment in message_chain.chain:
|
for segment in message_chain.chain:
|
||||||
d = segment.toDict()
|
|
||||||
if isinstance(segment, Plain):
|
if isinstance(segment, Plain):
|
||||||
d["type"] = "text"
|
if not segment.text.strip():
|
||||||
elif isinstance(segment, (Image, Record)):
|
continue
|
||||||
# convert to base64
|
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||||
if segment.file and segment.file.startswith("file:///"):
|
|
||||||
bs64_data = file_to_base64(segment.file[8:])
|
|
||||||
image_file_path = segment.file[8:]
|
|
||||||
elif segment.file and segment.file.startswith("http"):
|
|
||||||
image_file_path = await download_image_by_url(segment.file)
|
|
||||||
bs64_data = file_to_base64(image_file_path)
|
|
||||||
elif segment.file and segment.file.startswith("base64://"):
|
|
||||||
bs64_data = segment.file
|
|
||||||
else:
|
|
||||||
bs64_data = file_to_base64(segment.file)
|
|
||||||
d["data"] = {
|
|
||||||
"file": bs64_data,
|
|
||||||
}
|
|
||||||
elif isinstance(segment, At):
|
|
||||||
d["data"] = {
|
|
||||||
"qq": str(segment.qq) # 转换为字符串
|
|
||||||
}
|
|
||||||
ret.append(d)
|
ret.append(d)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||||
|
send_one_by_one = any(
|
||||||
send_one_by_one = False
|
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
||||||
for seg in message.chain:
|
)
|
||||||
if isinstance(seg, (Node, Nodes)):
|
|
||||||
# 转发消息不能和普通消息混在一起发送
|
|
||||||
send_one_by_one = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if send_one_by_one:
|
if send_one_by_one:
|
||||||
for seg in message.chain:
|
for seg in message.chain:
|
||||||
if isinstance(seg, Nodes):
|
if isinstance(seg, (Node, Nodes)):
|
||||||
# 带有多个节点的合并转发消息
|
# 合并转发消息
|
||||||
payload = seg.toDict()
|
|
||||||
|
if isinstance(seg, Node):
|
||||||
|
nodes = Nodes([seg])
|
||||||
|
seg = nodes
|
||||||
|
|
||||||
|
payload = await seg.to_dict()
|
||||||
|
|
||||||
if self.get_group_id():
|
if self.get_group_id():
|
||||||
payload["group_id"] = self.get_group_id()
|
payload["group_id"] = self.get_group_id()
|
||||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||||
@@ -66,6 +82,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
await self.bot.call_action(
|
await self.bot.call_action(
|
||||||
"send_private_forward_msg", **payload
|
"send_private_forward_msg", **payload
|
||||||
)
|
)
|
||||||
|
elif isinstance(seg, File):
|
||||||
|
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||||
|
await self.bot.send(
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
[d],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await self.bot.send(
|
await self.bot.send(
|
||||||
self.message_obj.raw_message,
|
self.message_obj.raw_message,
|
||||||
@@ -75,6 +97,86 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
else:
|
else:
|
||||||
|
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||||
|
if not ret:
|
||||||
|
return
|
||||||
await self.bot.send(self.message_obj.raw_message, ret)
|
await self.bot.send(self.message_obj.raw_message, ret)
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||||
|
):
|
||||||
|
if not use_fallback:
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
buffer += comp.text
|
||||||
|
if any(p in buffer for p in "。?!~…"):
|
||||||
|
buffer = await self.process_buffer(buffer, pattern)
|
||||||
|
else:
|
||||||
|
await self.send(MessageChain(chain=[comp]))
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
|
||||||
|
if buffer.strip():
|
||||||
|
await self.send(MessageChain([Plain(buffer)]))
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
async def get_group(self, group_id=None, **kwargs):
|
||||||
|
if isinstance(group_id, str) and group_id.isdigit():
|
||||||
|
group_id = int(group_id)
|
||||||
|
elif self.get_group_id():
|
||||||
|
group_id = int(self.get_group_id())
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
info: dict = await self.bot.call_action(
|
||||||
|
"get_group_info",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
members: List[Dict] = await self.bot.call_action(
|
||||||
|
"get_group_member_list",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
owner_id = None
|
||||||
|
admin_ids = []
|
||||||
|
for member in members:
|
||||||
|
if member["role"] == "owner":
|
||||||
|
owner_id = member["user_id"]
|
||||||
|
if member["role"] == "admin":
|
||||||
|
admin_ids.append(member["user_id"])
|
||||||
|
|
||||||
|
group = Group(
|
||||||
|
group_id=str(group_id),
|
||||||
|
group_name=info.get("group_name"),
|
||||||
|
group_avatar="",
|
||||||
|
group_admins=admin_ids,
|
||||||
|
group_owner=str(owner_id),
|
||||||
|
members=[
|
||||||
|
MessageMember(
|
||||||
|
user_id=member["user_id"],
|
||||||
|
nickname=member.get("nickname") or member.get("card"),
|
||||||
|
)
|
||||||
|
for member in members
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return group
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import os
|
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
import itertools
|
||||||
from typing import Awaitable, Any
|
from typing import Awaitable, Any
|
||||||
from aiocqhttp import CQHttp, Event
|
from aiocqhttp import CQHttp, Event
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
@@ -20,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
|||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
from aiocqhttp.exceptions import ActionFailed
|
from aiocqhttp.exceptions import ActionFailed
|
||||||
from astrbot.core.utils.io import download_file
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter(
|
@register_platform_adapter(
|
||||||
@@ -39,14 +38,18 @@ class AiocqhttpAdapter(Platform):
|
|||||||
self.port = platform_config["ws_reverse_port"]
|
self.port = platform_config["ws_reverse_port"]
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
"aiocqhttp",
|
name="aiocqhttp",
|
||||||
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stop = False
|
|
||||||
|
|
||||||
self.bot = CQHttp(
|
self.bot = CQHttp(
|
||||||
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
use_ws_reverse=True,
|
||||||
|
import_name="aiocqhttp",
|
||||||
|
api_timeout_sec=180,
|
||||||
|
access_token=platform_config.get(
|
||||||
|
"ws_reverse_token"
|
||||||
|
), # 以防旧版本配置不存在
|
||||||
)
|
)
|
||||||
|
|
||||||
@self.bot.on_request()
|
@self.bot.on_request()
|
||||||
@@ -100,6 +103,9 @@ class AiocqhttpAdapter(Platform):
|
|||||||
|
|
||||||
if event["post_type"] == "message":
|
if event["post_type"] == "message":
|
||||||
abm = await self._convert_handle_message_event(event)
|
abm = await self._convert_handle_message_event(event)
|
||||||
|
if abm.sender.user_id == "2854196310":
|
||||||
|
# 屏蔽 QQ 管家的消息
|
||||||
|
return
|
||||||
elif event["post_type"] == "notice":
|
elif event["post_type"] == "notice":
|
||||||
abm = await self._convert_handle_notice_event(event)
|
abm = await self._convert_handle_notice_event(event)
|
||||||
elif event["post_type"] == "request":
|
elif event["post_type"] == "request":
|
||||||
@@ -111,7 +117,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
"""OneBot V11 请求类事件"""
|
"""OneBot V11 请求类事件"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||||
abm.type = MessageType.OTHER_MESSAGE
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
if "group_id" in event and event["group_id"]:
|
if "group_id" in event and event["group_id"]:
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
@@ -120,6 +126,12 @@ class AiocqhttpAdapter(Platform):
|
|||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||||
|
else:
|
||||||
|
abm.session_id = (
|
||||||
|
str(event.group_id)
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE
|
||||||
|
else abm.sender.user_id
|
||||||
|
)
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
abm.message = []
|
abm.message = []
|
||||||
abm.timestamp = int(time.time())
|
abm.timestamp = int(time.time())
|
||||||
@@ -131,7 +143,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
"""OneBot V11 通知类事件"""
|
"""OneBot V11 通知类事件"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||||
abm.type = MessageType.OTHER_MESSAGE
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
if "group_id" in event and event["group_id"]:
|
if "group_id" in event and event["group_id"]:
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
@@ -140,7 +152,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
abm.session_id = (
|
abm.session_id = (
|
||||||
abm.sender.user_id + "_" + str(event.group_id)
|
str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||||
) # 也保留群组 id
|
) # 也保留群组 id
|
||||||
else:
|
else:
|
||||||
abm.session_id = (
|
abm.session_id = (
|
||||||
@@ -156,12 +168,20 @@ class AiocqhttpAdapter(Platform):
|
|||||||
|
|
||||||
if "sub_type" in event:
|
if "sub_type" in event:
|
||||||
if event["sub_type"] == "poke" and "target_id" in event:
|
if event["sub_type"] == "poke" and "target_id" in event:
|
||||||
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
abm.message.append(
|
||||||
|
Poke(qq=str(event["target_id"]), type="poke")
|
||||||
|
) # noqa: F405
|
||||||
|
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
|
async def _convert_handle_message_event(
|
||||||
"""OneBot V11 消息类事件"""
|
self, event: Event, get_reply=True
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""OneBot V11 消息类事件
|
||||||
|
|
||||||
|
@param event: 事件对象
|
||||||
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
|
"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
@@ -197,52 +217,119 @@ class AiocqhttpAdapter(Platform):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 按消息段类型类型适配
|
# 按消息段类型类型适配
|
||||||
for m in event.message:
|
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||||
t = m["type"]
|
|
||||||
a = None
|
a = None
|
||||||
if t == "text":
|
if t == "text":
|
||||||
message_str += m["data"]["text"].strip()
|
current_text = "".join(m["data"]["text"] for m in m_group).strip()
|
||||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
message_str += current_text
|
||||||
|
a = ComponentTypes[t](text=current_text) # noqa: F405
|
||||||
abm.message.append(a)
|
abm.message.append(a)
|
||||||
|
|
||||||
elif t == "file":
|
elif t == "file":
|
||||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
for m in m_group:
|
||||||
# Lagrange
|
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||||
logger.info("guessing lagrange")
|
# Lagrange
|
||||||
|
logger.info("guessing lagrange")
|
||||||
|
file_name = m["data"].get("file_name", "file")
|
||||||
|
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# Napcat
|
||||||
|
ret = None
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
ret = await self.bot.call_action(
|
||||||
|
action="get_group_file_url",
|
||||||
|
file_id=event.message[0]["data"]["file_id"],
|
||||||
|
group_id=event.group_id,
|
||||||
|
)
|
||||||
|
elif abm.type == MessageType.FRIEND_MESSAGE:
|
||||||
|
ret = await self.bot.call_action(
|
||||||
|
action="get_private_file_url",
|
||||||
|
file_id=event.message[0]["data"]["file_id"],
|
||||||
|
)
|
||||||
|
if ret and "url" in ret:
|
||||||
|
file_url = ret["url"] # https
|
||||||
|
a = File(name="", url=file_url)
|
||||||
|
abm.message.append(a)
|
||||||
|
else:
|
||||||
|
logger.error(f"获取文件失败: {ret}")
|
||||||
|
|
||||||
file_name = m["data"].get("file_name", "file")
|
except ActionFailed as e:
|
||||||
path = os.path.join("data/temp", file_name)
|
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||||
await download_file(m["data"]["url"], path)
|
except BaseException as e:
|
||||||
|
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||||
|
|
||||||
m["data"] = {"file": path, "name": file_name}
|
elif t == "reply":
|
||||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
for m in m_group:
|
||||||
abm.message.append(a)
|
if not get_reply:
|
||||||
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
# Napcat, LLBot
|
|
||||||
ret = await self.bot.call_action(
|
|
||||||
action="get_file",
|
|
||||||
file_id=event.message[0]["data"]["file_id"],
|
|
||||||
)
|
|
||||||
if not ret.get("file", None):
|
|
||||||
raise ValueError(f"无法解析文件响应: {ret}")
|
|
||||||
if not os.path.exists(ret["file"]):
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
|
|
||||||
)
|
|
||||||
|
|
||||||
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
|
|
||||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
abm.message.append(a)
|
abm.message.append(a)
|
||||||
except ActionFailed as e:
|
else:
|
||||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
try:
|
||||||
except BaseException as e:
|
reply_event_data = await self.bot.call_action(
|
||||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
action="get_msg",
|
||||||
|
message_id=int(m["data"]["id"]),
|
||||||
|
)
|
||||||
|
abm_reply = await self._convert_handle_message_event(
|
||||||
|
Event.from_payload(reply_event_data), get_reply=False
|
||||||
|
)
|
||||||
|
|
||||||
|
reply_seg = Reply(
|
||||||
|
id=abm_reply.message_id,
|
||||||
|
chain=abm_reply.message,
|
||||||
|
sender_id=abm_reply.sender.user_id,
|
||||||
|
sender_nickname=abm_reply.sender.nickname,
|
||||||
|
time=abm_reply.timestamp,
|
||||||
|
message_str=abm_reply.message_str,
|
||||||
|
text=abm_reply.message_str, # for compatibility
|
||||||
|
qq=abm_reply.sender.user_id, # for compatibility
|
||||||
|
)
|
||||||
|
|
||||||
|
abm.message.append(reply_seg)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"获取引用消息失败: {e}。")
|
||||||
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
|
abm.message.append(a)
|
||||||
|
elif t == "at":
|
||||||
|
first_at_self_processed = False
|
||||||
|
|
||||||
|
for m in m_group:
|
||||||
|
try:
|
||||||
|
if m["data"]["qq"] == "all":
|
||||||
|
abm.message.append(At(qq="all", name="全体成员"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
at_info = await self.bot.call_action(
|
||||||
|
action="get_stranger_info",
|
||||||
|
user_id=int(m["data"]["qq"]),
|
||||||
|
)
|
||||||
|
if at_info:
|
||||||
|
nickname = at_info.get("nick", "")
|
||||||
|
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||||
|
|
||||||
|
abm.message.append(
|
||||||
|
At(
|
||||||
|
qq=m["data"]["qq"],
|
||||||
|
name=nickname,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_at_self and not first_at_self_processed:
|
||||||
|
# 第一个@是机器人,不添加到message_str
|
||||||
|
first_at_self_processed = True
|
||||||
|
else:
|
||||||
|
# 非第一个@机器人或@其他用户,添加到message_str
|
||||||
|
message_str += f" @{nickname} "
|
||||||
|
else:
|
||||||
|
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||||
|
except ActionFailed as e:
|
||||||
|
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||||
else:
|
else:
|
||||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
for m in m_group:
|
||||||
abm.message.append(a)
|
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||||
|
abm.message.append(a)
|
||||||
|
|
||||||
abm.timestamp = int(time.time())
|
abm.timestamp = int(time.time())
|
||||||
abm.message_str = message_str
|
abm.message_str = message_str
|
||||||
@@ -267,22 +354,19 @@ class AiocqhttpAdapter(Platform):
|
|||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
|
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
return coro
|
return coro
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
self.stop = True
|
self.shutdown_event.set()
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
async def shutdown_trigger_placeholder(self):
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
logger.info("aiocqhttp 适配器已被优雅地关闭")
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return self.metadata
|
return self.metadata
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
|
||||||
# TODO: use asyncio.Event
|
|
||||||
while not self._event_queue.closed and not self.stop: # noqa: ASYNC110
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("aiocqhttp 适配器已关闭。")
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
message_event = AiocqhttpMessageEvent(
|
message_event = AiocqhttpMessageEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
|
|||||||
231
astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Normal file
231
astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import aiohttp
|
||||||
|
import dingtalk_stream
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Image, Plain, At
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from .dingtalk_event import DingtalkMessageEvent
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from astrbot import logger
|
||||||
|
from dingtalk_stream import AckMessage
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
|
||||||
|
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||||
|
async def process(self, event: dingtalk_stream.EventMessage):
|
||||||
|
print(
|
||||||
|
"2",
|
||||||
|
event.headers.event_type,
|
||||||
|
event.headers.event_id,
|
||||||
|
event.headers.event_born_time,
|
||||||
|
event.data,
|
||||||
|
)
|
||||||
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
|
||||||
|
class DingtalkPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
|
||||||
|
self.config = platform_config
|
||||||
|
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
self.client_id = platform_config["client_id"]
|
||||||
|
self.client_secret = platform_config["client_secret"]
|
||||||
|
|
||||||
|
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||||
|
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||||
|
logger.debug(f"dingtalk: {message.data}")
|
||||||
|
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||||
|
abm = await self.convert_msg(im)
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
|
self.client = AstrCallbackClient()
|
||||||
|
|
||||||
|
credential = dingtalk_stream.Credential(self.client_id, self.client_secret)
|
||||||
|
client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger)
|
||||||
|
client.register_all_event_handler(MyEventHandler())
|
||||||
|
client.register_callback_handler(
|
||||||
|
dingtalk_stream.ChatbotMessage.TOPIC, self.client
|
||||||
|
)
|
||||||
|
self.client_ = client # 用于 websockets 的 client
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="dingtalk",
|
||||||
|
description="钉钉机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def convert_msg(
|
||||||
|
self, message: dingtalk_stream.ChatbotMessage
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.message = []
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.timestamp = int(message.create_at / 1000)
|
||||||
|
abm.type = (
|
||||||
|
MessageType.GROUP_MESSAGE
|
||||||
|
if message.conversation_type == "2"
|
||||||
|
else MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
user_id=message.sender_id, nickname=message.sender_nick
|
||||||
|
)
|
||||||
|
abm.self_id = message.chatbot_user_id
|
||||||
|
abm.message_id = message.message_id
|
||||||
|
abm.raw_message = message
|
||||||
|
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
if message.is_in_at_list:
|
||||||
|
abm.message.append(At(qq=abm.self_id))
|
||||||
|
abm.group_id = message.conversation_id
|
||||||
|
if self.unique_session:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.group_id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
|
message_type: str = message.message_type
|
||||||
|
match message_type:
|
||||||
|
case "text":
|
||||||
|
abm.message_str = message.text.content.strip()
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
case "richText":
|
||||||
|
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||||
|
contents: list[dict] = rtc.rich_text_list
|
||||||
|
for content in contents:
|
||||||
|
plains = ""
|
||||||
|
if "text" in content:
|
||||||
|
plains += content["text"]
|
||||||
|
abm.message.append(Plain(plains))
|
||||||
|
elif "type" in content and content["type"] == "picture":
|
||||||
|
f_path = await self.download_ding_file(
|
||||||
|
content["downloadCode"],
|
||||||
|
message.robot_code,
|
||||||
|
"jpg",
|
||||||
|
)
|
||||||
|
abm.message.append(Image.fromFileSystem(f_path))
|
||||||
|
case "audio":
|
||||||
|
pass
|
||||||
|
|
||||||
|
return abm # 别忘了返回转换后的消息对象
|
||||||
|
|
||||||
|
async def download_ding_file(
|
||||||
|
self, download_code: str, robot_code: str, ext: str
|
||||||
|
) -> str:
|
||||||
|
"""下载钉钉文件
|
||||||
|
|
||||||
|
:param access_token: 钉钉机器人的 access_token
|
||||||
|
:param download_code: 下载码
|
||||||
|
:param robot_code: 机器人码
|
||||||
|
:param ext: 文件后缀
|
||||||
|
:return: 文件路径
|
||||||
|
"""
|
||||||
|
access_token = await self.get_access_token()
|
||||||
|
headers = {
|
||||||
|
"x-acs-dingtalk-access-token": access_token,
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"downloadCode": download_code,
|
||||||
|
"robotCode": robot_code,
|
||||||
|
}
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"下载钉钉文件失败: {resp.status}, {await resp.text()}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
resp_data = await resp.json()
|
||||||
|
download_url = resp_data["data"]["downloadUrl"]
|
||||||
|
await download_file(download_url, f_path)
|
||||||
|
return f_path
|
||||||
|
|
||||||
|
async def get_access_token(self) -> str:
|
||||||
|
payload = {
|
||||||
|
"appKey": self.client_id,
|
||||||
|
"appSecret": self.client_secret,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return (await resp.json())["data"]["accessToken"]
|
||||||
|
|
||||||
|
async def handle_msg(self, abm: AstrBotMessage):
|
||||||
|
event = DingtalkMessageEvent(
|
||||||
|
message_str=abm.message_str,
|
||||||
|
message_obj=abm,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=abm.session_id,
|
||||||
|
client=self.client,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._event_queue.put_nowait(event)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
# await self.client_.start()
|
||||||
|
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
|
||||||
|
def start_client(loop: asyncio.AbstractEventLoop):
|
||||||
|
try:
|
||||||
|
self._shutdown_event = threading.Event()
|
||||||
|
task = loop.create_task(self.client_.start())
|
||||||
|
self._shutdown_event.wait()
|
||||||
|
if task.done():
|
||||||
|
task.result()
|
||||||
|
except Exception as e:
|
||||||
|
if "Graceful shutdown" in str(e):
|
||||||
|
logger.info("钉钉适配器已被优雅地关闭")
|
||||||
|
return
|
||||||
|
logger.error(f"钉钉机器人启动失败: {e}")
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(None, start_client, loop)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
def monkey_patch_close():
|
||||||
|
raise Exception("Graceful shutdown")
|
||||||
|
|
||||||
|
self.client_.open_connection = monkey_patch_close
|
||||||
|
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
return self.client
|
||||||
75
astrbot/core/platform/sources/dingtalk/dingtalk_event.py
Normal file
75
astrbot/core/platform/sources/dingtalk/dingtalk_event.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import asyncio
|
||||||
|
import dingtalk_stream
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
|
class DingtalkMessageEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str,
|
||||||
|
message_obj,
|
||||||
|
platform_meta,
|
||||||
|
session_id,
|
||||||
|
client: dingtalk_stream.ChatbotHandler,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def send_with_client(
|
||||||
|
self, client: dingtalk_stream.ChatbotHandler, message: MessageChain
|
||||||
|
):
|
||||||
|
for segment in message.chain:
|
||||||
|
if isinstance(segment, Comp.Plain):
|
||||||
|
segment.text = segment.text.strip()
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
client.reply_markdown,
|
||||||
|
"AstrBot",
|
||||||
|
segment.text,
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
)
|
||||||
|
elif isinstance(segment, Comp.Image):
|
||||||
|
markdown_str = ""
|
||||||
|
if segment.file and segment.file.startswith("file:///"):
|
||||||
|
logger.warning(
|
||||||
|
"dingtalk only support url image, not: " + segment.file
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif segment.file and segment.file.startswith("http"):
|
||||||
|
markdown_str += f"\n\n"
|
||||||
|
elif segment.file and segment.file.startswith("base64://"):
|
||||||
|
logger.warning("dingtalk only support url image, not base64")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"dingtalk only support url image, not: " + segment.file
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
ret = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
client.reply_markdown,
|
||||||
|
"😄",
|
||||||
|
markdown_str,
|
||||||
|
self.message_obj.raw_message,
|
||||||
|
)
|
||||||
|
logger.debug(f"send image: {ret}")
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
await self.send_with_client(self.client, message)
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
@@ -1,17 +1,28 @@
|
|||||||
import threading
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
|
||||||
import quart
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import anyio
|
import anyio
|
||||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
import quart
|
||||||
from astrbot.api.message_components import Plain, Image, At, Record
|
|
||||||
from astrbot.api import logger, sp
|
from astrbot.api import logger, sp
|
||||||
from .downloader import GeweDownloader
|
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||||
|
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from .downloader import GeweDownloader
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .xml_data_parser import GeweDataParser
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
logger.warning(
|
||||||
|
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SimpleGewechatClient:
|
class SimpleGewechatClient:
|
||||||
@@ -51,11 +62,11 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"]
|
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||||
)
|
)
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/astrbot-gewechat/file/<file_id>",
|
"/astrbot-gewechat/file/<file_token>",
|
||||||
view_func=self.handle_file,
|
view_func=self._handle_file,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,9 +81,15 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
self.userrealnames = {}
|
self.userrealnames = {}
|
||||||
|
|
||||||
self.stop = False
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
self.staged_files = {}
|
||||||
|
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||||
|
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
async def get_token_id(self):
|
async def get_token_id(self):
|
||||||
|
"""获取 Gewechat Token。"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
@@ -87,6 +104,15 @@ class SimpleGewechatClient:
|
|||||||
type_name = data["type_name"]
|
type_name = data["type_name"]
|
||||||
else:
|
else:
|
||||||
raise Exception("无法识别的消息类型")
|
raise Exception("无法识别的消息类型")
|
||||||
|
|
||||||
|
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||||
|
if type_name == "ModContacts":
|
||||||
|
logger.info("gewechat下发:ModContacts消息通知。")
|
||||||
|
return
|
||||||
|
if type_name == "DelContacts":
|
||||||
|
logger.info("gewechat下发:DelContacts消息通知。")
|
||||||
|
return
|
||||||
|
|
||||||
if type_name == "Offline":
|
if type_name == "Offline":
|
||||||
logger.critical("收到 gewechat 下线通知。")
|
logger.critical("收到 gewechat 下线通知。")
|
||||||
return
|
return
|
||||||
@@ -124,18 +150,25 @@ class SimpleGewechatClient:
|
|||||||
content = d["Content"]["string"] # 消息内容
|
content = d["Content"]["string"] # 消息内容
|
||||||
|
|
||||||
at_me = False
|
at_me = False
|
||||||
|
at_wxids = []
|
||||||
if "@chatroom" in from_user_name:
|
if "@chatroom" in from_user_name:
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
_t = content.split(":\n")
|
_t = content.split(":\n")
|
||||||
user_id = _t[0]
|
user_id = _t[0]
|
||||||
content = _t[1]
|
content = _t[1]
|
||||||
|
# at
|
||||||
|
msg_source = d["MsgSource"]
|
||||||
if "\u2005" in content:
|
if "\u2005" in content:
|
||||||
# at
|
# at
|
||||||
# content = content.split('\u2005')[1]
|
# content = content.split('\u2005')[1]
|
||||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||||
|
at_wxids = re.findall(
|
||||||
|
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||||
|
msg_source,
|
||||||
|
)
|
||||||
|
|
||||||
abm.group_id = from_user_name
|
abm.group_id = from_user_name
|
||||||
# at
|
|
||||||
msg_source = d["MsgSource"]
|
|
||||||
if (
|
if (
|
||||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||||
@@ -147,9 +180,13 @@ class SimpleGewechatClient:
|
|||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
user_id = from_user_name
|
user_id = from_user_name
|
||||||
|
|
||||||
|
# 检查消息是否由自己发送,若是则忽略
|
||||||
|
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||||
|
# if user_id == abm.self_id:
|
||||||
|
# logger.info("忽略自己发送的消息")
|
||||||
|
# return None
|
||||||
|
|
||||||
abm.message = []
|
abm.message = []
|
||||||
if at_me:
|
|
||||||
abm.message.insert(0, At(qq=abm.self_id))
|
|
||||||
|
|
||||||
# 解析用户真实名字
|
# 解析用户真实名字
|
||||||
user_real_name = "unknown"
|
user_real_name = "unknown"
|
||||||
@@ -173,11 +210,28 @@ class SimpleGewechatClient:
|
|||||||
else:
|
else:
|
||||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||||
else:
|
else:
|
||||||
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0]
|
try:
|
||||||
|
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||||
|
user_real_name = info["nickName"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||||
|
user_real_name = user_id
|
||||||
|
|
||||||
|
if at_me:
|
||||||
|
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||||
|
for wxid in at_wxids:
|
||||||
|
# 群聊里 At 其他人的列表
|
||||||
|
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||||
|
abm.message.append(At(qq=wxid, name=_username))
|
||||||
|
|
||||||
abm.sender = MessageMember(user_id, user_real_name)
|
abm.sender = MessageMember(user_id, user_real_name)
|
||||||
abm.raw_message = d
|
abm.raw_message = d
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
|
|
||||||
|
if user_id == "weixin":
|
||||||
|
# 忽略微信团队消息
|
||||||
|
return
|
||||||
|
|
||||||
# 不同消息类型
|
# 不同消息类型
|
||||||
match d["MsgType"]:
|
match d["MsgType"]:
|
||||||
case 1:
|
case 1:
|
||||||
@@ -195,18 +249,48 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
case 34:
|
case 34:
|
||||||
# 语音消息
|
# 语音消息
|
||||||
# data = await self.multimedia_downloader.download_voice(
|
|
||||||
# self.appid,
|
|
||||||
# content,
|
|
||||||
# abm.message_id
|
|
||||||
# )
|
|
||||||
# print(data)
|
|
||||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
file_path = os.path.join(
|
||||||
|
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||||
|
)
|
||||||
|
|
||||||
async with await anyio.open_file(file_path, "wb") as f:
|
async with await anyio.open_file(file_path, "wb") as f:
|
||||||
await f.write(voice_data)
|
await f.write(voice_data)
|
||||||
abm.message.append(Record(file=file_path, url=file_path))
|
abm.message.append(Record(file=file_path, url=file_path))
|
||||||
|
|
||||||
|
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||||
|
case 37: # 好友申请
|
||||||
|
logger.info("消息类型(37):好友申请")
|
||||||
|
case 42: # 名片
|
||||||
|
logger.info("消息类型(42):名片")
|
||||||
|
case 43: # 视频
|
||||||
|
video = Video(file="", cover=content)
|
||||||
|
abm.message.append(video)
|
||||||
|
case 47: # emoji
|
||||||
|
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||||
|
emoji = data_parser.parse_emoji()
|
||||||
|
abm.message.append(emoji)
|
||||||
|
case 48: # 地理位置
|
||||||
|
logger.info("消息类型(48):地理位置")
|
||||||
|
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||||
|
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||||
|
segments = data_parser.parse_mutil_49()
|
||||||
|
if segments:
|
||||||
|
abm.message.extend(segments)
|
||||||
|
for seg in segments:
|
||||||
|
if isinstance(seg, Plain):
|
||||||
|
abm.message_str += seg.text
|
||||||
|
case 51: # 帐号消息同步?
|
||||||
|
logger.info("消息类型(51):帐号消息同步?")
|
||||||
|
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||||
|
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||||
|
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||||
|
logger.info(
|
||||||
|
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||||
abm.raw_message = d
|
abm.raw_message = d
|
||||||
@@ -214,7 +298,7 @@ class SimpleGewechatClient:
|
|||||||
logger.debug(f"abm: {abm}")
|
logger.debug(f"abm: {abm}")
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
async def callback(self):
|
async def _callback(self):
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
logger.debug(f"收到 gewechat 回调: {data}")
|
logger.debug(f"收到 gewechat 回调: {data}")
|
||||||
|
|
||||||
@@ -236,9 +320,33 @@ class SimpleGewechatClient:
|
|||||||
|
|
||||||
return quart.jsonify({"r": "AstrBot ACK"})
|
return quart.jsonify({"r": "AstrBot ACK"})
|
||||||
|
|
||||||
async def handle_file(self, file_id):
|
async def _register_file(self, file_path: str) -> str:
|
||||||
file_path = f"data/temp/{file_id}"
|
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||||
return await quart.send_file(file_path)
|
|
||||||
|
Args:
|
||||||
|
file_path (str): 文件路径。
|
||||||
|
Returns:
|
||||||
|
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||||
|
"""
|
||||||
|
async with self.lock:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise Exception(f"文件不存在: {file_path}")
|
||||||
|
|
||||||
|
file_token = str(uuid.uuid4())
|
||||||
|
self.staged_files[file_token] = file_path
|
||||||
|
return file_token
|
||||||
|
|
||||||
|
async def _handle_file(self, file_token):
|
||||||
|
async with self.lock:
|
||||||
|
if file_token not in self.staged_files:
|
||||||
|
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||||
|
return quart.abort(404)
|
||||||
|
if not os.path.exists(self.staged_files[file_token]):
|
||||||
|
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||||
|
return quart.abort(404)
|
||||||
|
file_path = self.staged_files[file_token]
|
||||||
|
self.staged_files.pop(file_token, None)
|
||||||
|
return await quart.send_file(file_path)
|
||||||
|
|
||||||
async def _set_callback_url(self):
|
async def _set_callback_url(self):
|
||||||
logger.info("设置回调,请等待...")
|
logger.info("设置回调,请等待...")
|
||||||
@@ -262,17 +370,14 @@ class SimpleGewechatClient:
|
|||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder,
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
# TODO: use asyncio.Event
|
await self.shutdown_event.wait()
|
||||||
while not self.event_queue.closed and not self.stop: # noqa: ASYNC110
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("gewechat 适配器已关闭。")
|
|
||||||
|
|
||||||
async def check_online(self, appid: str):
|
async def check_online(self, appid: str):
|
||||||
# /login/checkOnline
|
"""检查 APPID 对应的设备是否在线。"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/login/checkOnline",
|
f"{self.base_url}/login/checkOnline",
|
||||||
@@ -283,6 +388,7 @@ class SimpleGewechatClient:
|
|||||||
return json_blob["data"]
|
return json_blob["data"]
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
|
"""登出 gewechat。"""
|
||||||
if self.appid:
|
if self.appid:
|
||||||
online = await self.check_online(self.appid)
|
online = await self.check_online(self.appid)
|
||||||
if online:
|
if online:
|
||||||
@@ -296,6 +402,7 @@ class SimpleGewechatClient:
|
|||||||
logger.info(f"登出结果: {json_blob}")
|
logger.info(f"登出结果: {json_blob}")
|
||||||
|
|
||||||
async def login(self):
|
async def login(self):
|
||||||
|
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||||
if self.token is None:
|
if self.token is None:
|
||||||
await self.get_token_id()
|
await self.get_token_id()
|
||||||
|
|
||||||
@@ -304,32 +411,49 @@ class SimpleGewechatClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.appid:
|
if self.appid:
|
||||||
online = await self.check_online(self.appid)
|
try:
|
||||||
if online:
|
online = await self.check_online(self.appid)
|
||||||
logger.info(f"APPID: {self.appid} 已在线")
|
if online:
|
||||||
return
|
logger.info(f"APPID: {self.appid} 已在线")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查在线状态失败: {e}")
|
||||||
|
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||||
|
self.appid = None
|
||||||
|
|
||||||
payload = {"appId": self.appid}
|
payload = {"appId": self.appid}
|
||||||
|
|
||||||
if self.appid:
|
if self.appid:
|
||||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
try:
|
||||||
async with session.post(
|
async with aiohttp.ClientSession() as session:
|
||||||
f"{self.base_url}/login/getLoginQrCode",
|
async with session.post(
|
||||||
headers=self.headers,
|
f"{self.base_url}/login/getLoginQrCode",
|
||||||
json=payload,
|
headers=self.headers,
|
||||||
) as resp:
|
json=payload,
|
||||||
json_blob = await resp.json()
|
) as resp:
|
||||||
if json_blob["ret"] != 200:
|
json_blob = await resp.json()
|
||||||
raise Exception(f"获取二维码失败: {json_blob}")
|
if json_blob["ret"] != 200:
|
||||||
qr_data = json_blob["data"]["qrData"]
|
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||||
qr_uuid = json_blob["data"]["uuid"]
|
if "设备不存在" in error_msg:
|
||||||
appid = json_blob["data"]["appId"]
|
logger.error(
|
||||||
logger.info(f"APPID: {appid}")
|
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||||
logger.warning(
|
)
|
||||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||||
)
|
self.appid = None
|
||||||
|
return await self.login()
|
||||||
|
else:
|
||||||
|
raise Exception(f"获取二维码失败: {json_blob}")
|
||||||
|
qr_data = json_blob["data"]["qrData"]
|
||||||
|
qr_uuid = json_blob["data"]["uuid"]
|
||||||
|
appid = json_blob["data"]["appId"]
|
||||||
|
logger.info(f"APPID: {appid}")
|
||||||
|
logger.warning(
|
||||||
|
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
# 执行登录
|
# 执行登录
|
||||||
retry_cnt = 64
|
retry_cnt = 64
|
||||||
@@ -338,8 +462,10 @@ class SimpleGewechatClient:
|
|||||||
retry_cnt -= 1
|
retry_cnt -= 1
|
||||||
|
|
||||||
# 需要验证码
|
# 需要验证码
|
||||||
if os.path.exists("data/temp/gewe_code"):
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
with open("data/temp/gewe_code", "r") as f:
|
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||||
|
if os.path.exists(code_file_path):
|
||||||
|
with open(code_file_path, "r") as f:
|
||||||
code = f.read().strip()
|
code = f.read().strip()
|
||||||
if not code:
|
if not code:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -350,9 +476,9 @@ class SimpleGewechatClient:
|
|||||||
payload["captchCode"] = code
|
payload["captchCode"] = code
|
||||||
logger.info(f"使用验证码: {code}")
|
logger.info(f"使用验证码: {code}")
|
||||||
try:
|
try:
|
||||||
os.remove("data/temp/gewe_code")
|
os.remove(code_file_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
@@ -372,17 +498,18 @@ class SimpleGewechatClient:
|
|||||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
status = json_blob["data"]["status"]
|
if "status" in json_blob["data"]:
|
||||||
nickname = json_blob["data"].get("nickName", "")
|
status = json_blob["data"]["status"]
|
||||||
if status == 1:
|
nickname = json_blob["data"].get("nickName", "")
|
||||||
logger.info(f"等待确认...{nickname}")
|
if status == 1:
|
||||||
elif status == 2:
|
logger.info(f"等待确认...{nickname}")
|
||||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
elif status == 2:
|
||||||
break
|
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||||
elif status == 0:
|
break
|
||||||
logger.info("等待扫码...")
|
elif status == 0:
|
||||||
else:
|
logger.info("等待扫码...")
|
||||||
logger.warning(f"未知状态: {status}")
|
else:
|
||||||
|
logger.warning(f"未知状态: {status}")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
if appid:
|
if appid:
|
||||||
@@ -390,9 +517,18 @@ class SimpleGewechatClient:
|
|||||||
self.appid = appid
|
self.appid = appid
|
||||||
logger.info(f"已保存 APPID: {appid}")
|
logger.info(f"已保存 APPID: {appid}")
|
||||||
|
|
||||||
"""API"""
|
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||||
|
"""
|
||||||
|
|
||||||
async def get_chatroom_member_list(self, chatroom_wxid: str):
|
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||||
|
"""获取群成员列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||||
|
"""
|
||||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
@@ -405,6 +541,7 @@ class SimpleGewechatClient:
|
|||||||
return json_blob["data"]
|
return json_blob["data"]
|
||||||
|
|
||||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||||
|
"""发送纯文本消息"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -421,6 +558,7 @@ class SimpleGewechatClient:
|
|||||||
logger.debug(f"发送消息结果: {json_blob}")
|
logger.debug(f"发送消息结果: {json_blob}")
|
||||||
|
|
||||||
async def post_image(self, to_wxid, image_url: str):
|
async def post_image(self, to_wxid, image_url: str):
|
||||||
|
"""发送图片消息"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -434,7 +572,79 @@ class SimpleGewechatClient:
|
|||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送图片结果: {json_blob}")
|
logger.debug(f"发送图片结果: {json_blob}")
|
||||||
|
|
||||||
|
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||||
|
"""发送emoji消息"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"emojiMd5": emoji_md5,
|
||||||
|
"emojiSize": emoji_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||||
|
try:
|
||||||
|
if emoji_md5 != "" and emoji_size != "":
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postEmoji",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.info(
|
||||||
|
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.post_image(to_wxid, cdnurl)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
async def post_video(
|
||||||
|
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"videoUrl": video_url,
|
||||||
|
"thumbUrl": thumb_url,
|
||||||
|
"videoDuration": video_duration,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"发送视频结果: {json_blob}")
|
||||||
|
|
||||||
|
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||||
|
"""转发视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_wxid (str): 发送给谁
|
||||||
|
cnd_xml (str): 视频消息的cdn信息
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"toWxid": to_wxid,
|
||||||
|
"xml": cnd_xml,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/message/forwardVideo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"转发视频结果: {json_blob}")
|
||||||
|
|
||||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||||
|
"""发送语音信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice_url (str): 语音文件的网络链接
|
||||||
|
voice_duration (int): 语音时长,毫秒
|
||||||
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -449,9 +659,16 @@ class SimpleGewechatClient:
|
|||||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送语音结果: {json_blob}")
|
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||||
|
|
||||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||||
|
"""发送文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_wxid (string): 微信ID
|
||||||
|
file_url (str): 文件的网络链接
|
||||||
|
file_name (str): 文件名
|
||||||
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"appId": self.appid,
|
"appId": self.appid,
|
||||||
"toWxid": to_wxid,
|
"toWxid": to_wxid,
|
||||||
@@ -465,3 +682,131 @@ class SimpleGewechatClient:
|
|||||||
) as resp:
|
) as resp:
|
||||||
json_blob = await resp.json()
|
json_blob = await resp.json()
|
||||||
logger.debug(f"发送文件结果: {json_blob}")
|
logger.debug(f"发送文件结果: {json_blob}")
|
||||||
|
|
||||||
|
async def add_friend(self, v3: str, v4: str, content: str):
|
||||||
|
"""申请添加好友"""
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"scene": 3,
|
||||||
|
"content": content,
|
||||||
|
"v4": v4,
|
||||||
|
"v3": v3,
|
||||||
|
"option": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/addContacts",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomInfo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_group_member(self, group_id: str):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/getChatroomMemberList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def accept_group_invite(self, url: str):
|
||||||
|
"""同意进群"""
|
||||||
|
payload = {"appId": self.appid, "url": url}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/agreeJoinRoom",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def add_group_member_to_friend(
|
||||||
|
self, group_id: str, to_wxid: str, content: str
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"chatroomId": group_id,
|
||||||
|
"content": content,
|
||||||
|
"memberWxid": to_wxid,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_user_or_group_info(self, *ids):
|
||||||
|
"""
|
||||||
|
获取用户或群组信息。
|
||||||
|
|
||||||
|
:param ids: 可变数量的 wxid 参数
|
||||||
|
"""
|
||||||
|
|
||||||
|
wxids_str = list(ids)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"appId": self.appid,
|
||||||
|
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/getDetailInfo",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取群信息结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|
||||||
|
async def get_contacts_list(self):
|
||||||
|
"""
|
||||||
|
获取通讯录列表
|
||||||
|
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||||
|
"""
|
||||||
|
payload = {"appId": self.appid}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/contacts/fetchContactsList",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
json_blob = await resp.json()
|
||||||
|
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||||
|
return json_blob
|
||||||
|
|||||||
@@ -39,3 +39,17 @@ class GeweDownloader:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
raise Exception("无法下载图片")
|
raise Exception("无法下载图片")
|
||||||
|
|
||||||
|
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||||
|
"""下载emoji"""
|
||||||
|
try:
|
||||||
|
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||||
|
|
||||||
|
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||||
|
data = await self._post_json(
|
||||||
|
self.base_url, "/message/downloadEmojiMd5", payload
|
||||||
|
)
|
||||||
|
json_blob = json.loads(data)
|
||||||
|
return json_blob
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error(f"gewe download emoji: {e}")
|
||||||
|
|||||||
@@ -1,14 +1,27 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
import wave
|
import wave
|
||||||
import uuid
|
import uuid
|
||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||||
from astrbot.api.message_components import Plain, Image, Record, At, File
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
Record,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Video,
|
||||||
|
WechatEmoji as Emoji,
|
||||||
|
)
|
||||||
from .client import SimpleGewechatClient
|
from .client import SimpleGewechatClient
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
|
||||||
def get_wav_duration(file_path):
|
def get_wav_duration(file_path):
|
||||||
@@ -70,39 +83,84 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
await client.post_text(**payload)
|
await client.post_text(**payload)
|
||||||
|
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
img_url = comp.file
|
img_path = await comp.convert_to_file_path()
|
||||||
img_path = ""
|
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||||
if img_url.startswith("file:///"):
|
token = await client._register_file(img_path)
|
||||||
img_path = img_url[8:]
|
img_url = f"{client.file_server_url}/{token}"
|
||||||
elif comp.file and comp.file.startswith("http"):
|
|
||||||
img_path = await download_image_by_url(comp.file)
|
|
||||||
else:
|
|
||||||
img_path = img_url
|
|
||||||
|
|
||||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
|
||||||
temp_directory = os.path.abspath("data/temp")
|
|
||||||
img_path = os.path.abspath(img_path)
|
|
||||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
|
||||||
with open(img_path, "rb") as f:
|
|
||||||
img_path = save_temp_img(f.read())
|
|
||||||
|
|
||||||
file_id = os.path.basename(img_path)
|
|
||||||
img_url = f"{client.file_server_url}/{file_id}"
|
|
||||||
logger.debug(f"gewe callback img url: {img_url}")
|
logger.debug(f"gewe callback img url: {img_url}")
|
||||||
await client.post_image(to_wxid, img_url)
|
await client.post_image(to_wxid, img_url)
|
||||||
|
elif isinstance(comp, Video):
|
||||||
|
if comp.cover != "":
|
||||||
|
await client.forward_video(to_wxid, comp.cover)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFmpeg
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
logger.error(
|
||||||
|
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||||
|
)
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
video_url = comp.file
|
||||||
|
# 根据 url 下载视频
|
||||||
|
if video_url.startswith("http"):
|
||||||
|
video_filename = f"{uuid.uuid4()}.mp4"
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
video_path = os.path.join(temp_dir, video_filename)
|
||||||
|
await download_file(video_url, video_path)
|
||||||
|
else:
|
||||||
|
video_path = video_url
|
||||||
|
|
||||||
|
video_token = await client._register_file(video_path)
|
||||||
|
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||||
|
|
||||||
|
# 获取视频第一帧
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
thumb_path = os.path.join(
|
||||||
|
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||||
|
)
|
||||||
|
|
||||||
|
video_path = video_path.replace(" ", "\\ ")
|
||||||
|
try:
|
||||||
|
ff = FFmpeg()
|
||||||
|
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||||
|
ff.options(command)
|
||||||
|
thumb_token = await client._register_file(thumb_path)
|
||||||
|
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取视频第一帧失败: {e}")
|
||||||
|
|
||||||
|
# 获取视频时长
|
||||||
|
try:
|
||||||
|
from pyffmpeg import FFprobe
|
||||||
|
|
||||||
|
# 创建 FFprobe 实例
|
||||||
|
ffprobe = FFprobe(video_url)
|
||||||
|
# 获取时长字符串
|
||||||
|
duration_str = ffprobe.duration
|
||||||
|
# 处理时长字符串
|
||||||
|
video_duration = float(duration_str.replace(":", ""))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取时长失败: {e}")
|
||||||
|
video_duration = 10
|
||||||
|
|
||||||
|
# 发送视频
|
||||||
|
await client.post_video(
|
||||||
|
to_wxid, video_callback_url, thumb_url, video_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除临时缩略图文件
|
||||||
|
if os.path.exists(thumb_path):
|
||||||
|
os.remove(thumb_path)
|
||||||
elif isinstance(comp, Record):
|
elif isinstance(comp, Record):
|
||||||
# 默认已经存在 data/temp 中
|
# 默认已经存在 data/temp 中
|
||||||
record_url = comp.file
|
record_url = comp.file
|
||||||
record_path = ""
|
record_path = await comp.convert_to_file_path()
|
||||||
|
|
||||||
if record_url.startswith("file:///"):
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
record_path = record_url[8:]
|
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||||
elif record_url.startswith("http"):
|
|
||||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
|
||||||
else:
|
|
||||||
record_path = record_url
|
|
||||||
|
|
||||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
|
||||||
try:
|
try:
|
||||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -111,8 +169,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||||
if duration == 0:
|
if duration == 0:
|
||||||
duration = get_wav_duration(record_path)
|
duration = get_wav_duration(record_path)
|
||||||
file_id = os.path.basename(silk_path)
|
token = await client._register_file(silk_path)
|
||||||
record_url = f"{client.file_server_url}/{file_id}"
|
record_url = f"{client.file_server_url}/{token}"
|
||||||
logger.debug(f"gewe callback record url: {record_url}")
|
logger.debug(f"gewe callback record url: {record_url}")
|
||||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||||
elif isinstance(comp, File):
|
elif isinstance(comp, File):
|
||||||
@@ -121,14 +179,19 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
if file_path.startswith("file:///"):
|
if file_path.startswith("file:///"):
|
||||||
file_path = file_path[8:]
|
file_path = file_path[8:]
|
||||||
elif file_path.startswith("http"):
|
elif file_path.startswith("http"):
|
||||||
await download_file(file_path, f"data/temp/{file_name}")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
temp_file_path = os.path.join(temp_dir, file_name)
|
||||||
|
await download_file(file_path, temp_file_path)
|
||||||
|
file_path = temp_file_path
|
||||||
else:
|
else:
|
||||||
file_path = file_path
|
file_path = file_path
|
||||||
|
|
||||||
file_id = os.path.basename(file_path)
|
token = await client._register_file(file_path)
|
||||||
file_url = f"{client.file_server_url}/{file_id}"
|
file_url = f"{client.file_server_url}/{token}"
|
||||||
logger.debug(f"gewe callback file url: {file_url}")
|
logger.debug(f"gewe callback file url: {file_url}")
|
||||||
await client.post_file(to_wxid, file_url, file_id)
|
await client.post_file(to_wxid, file_url, file_name)
|
||||||
|
elif isinstance(comp, Emoji):
|
||||||
|
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||||
elif isinstance(comp, At):
|
elif isinstance(comp, At):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@@ -138,3 +201,64 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
|||||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||||
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def get_group(self, group_id=None, **kwargs):
|
||||||
|
# 确定有效的 group_id
|
||||||
|
if group_id is None:
|
||||||
|
group_id = self.get_group_id()
|
||||||
|
|
||||||
|
if not group_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
res = await self.client.get_group(group_id)
|
||||||
|
data: dict = res["data"]
|
||||||
|
|
||||||
|
if not data["chatroomId"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
members = [
|
||||||
|
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||||
|
for member in data.get("memberList", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
return Group(
|
||||||
|
group_id=data["chatroomId"],
|
||||||
|
group_name=data.get("nickName"),
|
||||||
|
group_avatar=data.get("smallHeadImgUrl"),
|
||||||
|
group_owner=data.get("chatRoomOwner"),
|
||||||
|
members=members,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||||
|
):
|
||||||
|
if not use_fallback:
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
buffer += comp.text
|
||||||
|
if any(p in buffer for p in "。?!~…"):
|
||||||
|
buffer = await self.process_buffer(buffer, pattern)
|
||||||
|
else:
|
||||||
|
await self.send(MessageChain(chain=[comp]))
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
|
||||||
|
if buffer.strip():
|
||||||
|
await self.send(MessageChain([Plain(buffer)]))
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
|||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
from .gewechat_event import GewechatPlatformEvent
|
from .gewechat_event import GewechatPlatformEvent
|
||||||
from .client import SimpleGewechatClient
|
from .client import SimpleGewechatClient
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -59,13 +60,18 @@ class GewechatPlatformAdapter(Platform):
|
|||||||
@override
|
@override
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"gewechat",
|
name="gewechat",
|
||||||
"基于 gewechat 的 Wechat 适配器",
|
description="基于 gewechat 的 Wechat 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
self.client.stop = True
|
self.client.shutdown_event.set()
|
||||||
await asyncio.sleep(1)
|
try:
|
||||||
|
await self.client.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
await self.client.logout()
|
await self.client.logout()
|
||||||
|
|||||||
110
astrbot/core/platform/sources/gewechat/xml_data_parser.py
Normal file
110
astrbot/core/platform/sources/gewechat/xml_data_parser.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
from defusedxml import ElementTree as eT
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.message_components import (
|
||||||
|
WechatEmoji as Emoji,
|
||||||
|
Reply,
|
||||||
|
Plain,
|
||||||
|
BaseMessageComponent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeweDataParser:
|
||||||
|
def __init__(self, data, is_private_chat):
|
||||||
|
self.data = data
|
||||||
|
self.is_private_chat = is_private_chat
|
||||||
|
|
||||||
|
def _format_to_xml(self):
|
||||||
|
return eT.fromstring(self.data)
|
||||||
|
|
||||||
|
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||||
|
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||||
|
if appmsg_type is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
match appmsg_type.text:
|
||||||
|
case "57":
|
||||||
|
return self.parse_reply()
|
||||||
|
|
||||||
|
def parse_emoji(self) -> Emoji | None:
|
||||||
|
try:
|
||||||
|
emoji_element = self._format_to_xml().find(".//emoji")
|
||||||
|
# 提取 md5 和 len 属性
|
||||||
|
if emoji_element is not None:
|
||||||
|
md5_value = emoji_element.get("md5")
|
||||||
|
emoji_size = emoji_element.get("len")
|
||||||
|
cdnurl = emoji_element.get("cdnurl")
|
||||||
|
|
||||||
|
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||||
|
|
||||||
|
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||||
|
"""解析引用消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
replied_id = -1
|
||||||
|
replied_uid = 0
|
||||||
|
replied_nickname = ""
|
||||||
|
replied_content = "" # 被引用者说的内容
|
||||||
|
content = "" # 引用者说的内容
|
||||||
|
|
||||||
|
root = self._format_to_xml()
|
||||||
|
refermsg = root.find(".//refermsg")
|
||||||
|
if refermsg is not None:
|
||||||
|
# 被引用的信息
|
||||||
|
svrid = refermsg.find("svrid")
|
||||||
|
fromusr = refermsg.find("fromusr")
|
||||||
|
displayname = refermsg.find("displayname")
|
||||||
|
refermsg_content = refermsg.find("content")
|
||||||
|
if svrid is not None:
|
||||||
|
replied_id = svrid.text
|
||||||
|
if fromusr is not None:
|
||||||
|
replied_uid = fromusr.text
|
||||||
|
if displayname is not None:
|
||||||
|
replied_nickname = displayname.text
|
||||||
|
if refermsg_content is not None:
|
||||||
|
# 处理引用嵌套,包括嵌套公众号消息
|
||||||
|
if refermsg_content.text.startswith(
|
||||||
|
"<msg>"
|
||||||
|
) or refermsg_content.text.startswith("<?xml"):
|
||||||
|
try:
|
||||||
|
logger.debug("gewechat: Reference message is nested")
|
||||||
|
refer_root = eT.fromstring(refermsg_content.text)
|
||||||
|
img = refer_root.find("img")
|
||||||
|
if img is not None:
|
||||||
|
replied_content = "[图片]"
|
||||||
|
else:
|
||||||
|
app_msg = refer_root.find("appmsg")
|
||||||
|
refermsg_content_title = app_msg.find("title")
|
||||||
|
logger.debug(
|
||||||
|
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||||
|
)
|
||||||
|
replied_content = refermsg_content_title.text
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: nested failed, {e}")
|
||||||
|
# 处理异常情况
|
||||||
|
replied_content = refermsg_content.text
|
||||||
|
else:
|
||||||
|
replied_content = refermsg_content.text
|
||||||
|
|
||||||
|
# 提取引用者说的内容
|
||||||
|
title = root.find(".//appmsg/title")
|
||||||
|
if title is not None:
|
||||||
|
content = title.text
|
||||||
|
|
||||||
|
reply_seg = Reply(
|
||||||
|
id=replied_id,
|
||||||
|
chain=[Plain(replied_content)],
|
||||||
|
sender_id=replied_uid,
|
||||||
|
sender_nickname=replied_nickname,
|
||||||
|
message_str=replied_content,
|
||||||
|
)
|
||||||
|
plain_seg = Plain(content)
|
||||||
|
return [reply_seg, plain_seg]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||||
@@ -2,6 +2,8 @@ import base64
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import uuid
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
Platform,
|
Platform,
|
||||||
@@ -11,7 +13,6 @@ from astrbot.api.platform import (
|
|||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
)
|
)
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api.message_components import Image, Plain, At
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from .lark_event import LarkMessageEvent
|
from .lark_event import LarkMessageEvent
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
@@ -66,12 +67,47 @@ class LarkPlatformAdapter(Platform):
|
|||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
):
|
):
|
||||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||||
|
wrapped = {
|
||||||
|
"zh_cn": {
|
||||||
|
"title": "",
|
||||||
|
"content": res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||||
|
id_type = "chat_id"
|
||||||
|
if "%" in session.session_id:
|
||||||
|
session.session_id = session.session_id.split("%")[1]
|
||||||
|
else:
|
||||||
|
id_type = "open_id"
|
||||||
|
|
||||||
|
request = (
|
||||||
|
CreateMessageRequest.builder()
|
||||||
|
.receive_id_type(id_type)
|
||||||
|
.request_body(
|
||||||
|
CreateMessageRequestBody.builder()
|
||||||
|
.receive_id(session.session_id)
|
||||||
|
.content(json.dumps(wrapped))
|
||||||
|
.msg_type("post")
|
||||||
|
.uuid(str(uuid.uuid4()))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.lark_api.im.v1.message.acreate(request)
|
||||||
|
|
||||||
|
if not response.success():
|
||||||
|
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||||
|
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"lark",
|
name="lark",
|
||||||
"飞书机器人官方 API 适配器",
|
description="飞书机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||||
@@ -92,7 +128,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
at_list = {}
|
at_list = {}
|
||||||
if message.mentions:
|
if message.mentions:
|
||||||
for m in message.mentions:
|
for m in message.mentions:
|
||||||
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
|
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||||
if m.name == self.bot_name:
|
if m.name == self.bot_name:
|
||||||
abm.self_id = m.id.open_id
|
abm.self_id = m.id.open_id
|
||||||
|
|
||||||
@@ -111,7 +147,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
if s in at_list:
|
if s in at_list:
|
||||||
abm.message.append(at_list[s])
|
abm.message.append(at_list[s])
|
||||||
else:
|
else:
|
||||||
abm.message.append(Plain(parts[i].strip()))
|
abm.message.append(Comp.Plain(parts[i].strip()))
|
||||||
elif message.message_type == "post":
|
elif message.message_type == "post":
|
||||||
_ls = []
|
_ls = []
|
||||||
|
|
||||||
@@ -132,7 +168,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
if comp["tag"] == "at":
|
if comp["tag"] == "at":
|
||||||
abm.message.append(at_list[comp["user_id"]])
|
abm.message.append(at_list[comp["user_id"]])
|
||||||
elif comp["tag"] == "text" and comp["text"].strip():
|
elif comp["tag"] == "text" and comp["text"].strip():
|
||||||
abm.message.append(Plain(comp["text"].strip()))
|
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||||
elif comp["tag"] == "img":
|
elif comp["tag"] == "img":
|
||||||
image_key = comp["image_key"]
|
image_key = comp["image_key"]
|
||||||
request = (
|
request = (
|
||||||
@@ -147,10 +183,10 @@ class LarkPlatformAdapter(Platform):
|
|||||||
logger.error(f"无法下载飞书图片: {image_key}")
|
logger.error(f"无法下载飞书图片: {image_key}")
|
||||||
image_bytes = response.file.read()
|
image_bytes = response.file.read()
|
||||||
image_base64 = base64.b64encode(image_bytes).decode()
|
image_base64 = base64.b64encode(image_bytes).decode()
|
||||||
abm.message.append(Image.fromBase64(image_base64))
|
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||||
|
|
||||||
for comp in abm.message:
|
for comp in abm.message:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Comp.Plain):
|
||||||
abm.message_str += comp.text
|
abm.message_str += comp.text
|
||||||
abm.message_id = message.message_id
|
abm.message_id = message.message_id
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
@@ -165,7 +201,10 @@ class LarkPlatformAdapter(Platform):
|
|||||||
else:
|
else:
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
else:
|
else:
|
||||||
abm.session_id = abm.sender.user_id
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||||
|
else:
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
logger.debug(abm)
|
logger.debug(abm)
|
||||||
await self.handle_msg(abm)
|
await self.handle_msg(abm)
|
||||||
@@ -185,5 +224,9 @@ class LarkPlatformAdapter(Platform):
|
|||||||
# self.client.start()
|
# self.client.start()
|
||||||
await self.client._connect()
|
await self.client._connect()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
await self.client._disconnect()
|
||||||
|
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||||
|
|
||||||
def get_client(self) -> lark.Client:
|
def get_client(self) -> lark.Client:
|
||||||
return self.client
|
return self.client
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
import base64
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
|
from io import BytesIO
|
||||||
from typing import List
|
from typing import List
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from lark_oapi.api.im.v1 import *
|
from lark_oapi.api.im.v1 import *
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
|
||||||
class LarkMessageEvent(AstrMessageEvent):
|
class LarkMessageEvent(AstrMessageEvent):
|
||||||
@@ -27,22 +31,33 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
||||||
elif isinstance(comp, AstrBotImage):
|
elif isinstance(comp, AstrBotImage):
|
||||||
file_path = ""
|
file_path = ""
|
||||||
|
image_file = None
|
||||||
|
|
||||||
if comp.file and comp.file.startswith("file:///"):
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
file_path = comp.file.replace("file:///", "")
|
file_path = comp.file.replace("file:///", "")
|
||||||
elif comp.file and comp.file.startswith("http"):
|
elif comp.file and comp.file.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(comp.file)
|
image_file_path = await download_image_by_url(comp.file)
|
||||||
file_path = image_file_path
|
file_path = image_file_path
|
||||||
elif comp.file and comp.file.startswith("base64://"):
|
elif comp.file and comp.file.startswith("base64://"):
|
||||||
pass
|
base64_str = comp.file.removeprefix("base64://")
|
||||||
|
image_data = base64.b64decode(base64_str)
|
||||||
|
# save as temp file
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(BytesIO(image_data).getvalue())
|
||||||
else:
|
else:
|
||||||
file_path = comp.file
|
file_path = comp.file
|
||||||
|
|
||||||
|
if image_file is None:
|
||||||
|
image_file = open(file_path, "rb")
|
||||||
|
|
||||||
request = (
|
request = (
|
||||||
CreateImageRequest.builder()
|
CreateImageRequest.builder()
|
||||||
.request_body(
|
.request_body(
|
||||||
CreateImageRequestBody.builder()
|
CreateImageRequestBody.builder()
|
||||||
.image_type("message")
|
.image_type("message")
|
||||||
.image(open(file_path, "rb"))
|
.image(image_file)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
.build()
|
.build()
|
||||||
@@ -51,7 +66,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||||
image_key = response.data.image_key
|
image_key = response.data.image_key
|
||||||
print(image_key)
|
logger.debug(image_key)
|
||||||
ret.append(_stage)
|
ret.append(_stage)
|
||||||
ret.append([{"tag": "img", "image_key": image_key}])
|
ret.append([{"tag": "img", "image_key": image_key}])
|
||||||
_stage.clear()
|
_stage.clear()
|
||||||
@@ -91,3 +106,16 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import botpy
|
|||||||
import botpy.message
|
import botpy.message
|
||||||
import botpy.types
|
import botpy.types
|
||||||
import botpy.types.message
|
import botpy.types.message
|
||||||
|
import asyncio
|
||||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
@@ -9,6 +10,8 @@ from astrbot.api.message_components import Plain, Image
|
|||||||
from botpy import Client
|
from botpy import Client
|
||||||
from botpy.http import Route
|
from botpy.http import Route
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
|
from botpy.types import message
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||||
@@ -30,8 +33,45 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
self.send_buffer.chain.extend(message.chain)
|
self.send_buffer.chain.extend(message.chain)
|
||||||
|
|
||||||
async def _post_send(self):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
"""QQ 官方 API 仅支持回复一次"""
|
"""流式输出仅支持消息列表私聊"""
|
||||||
|
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||||
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
|
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||||
|
try:
|
||||||
|
async for chain in generator:
|
||||||
|
source = self.message_obj.raw_message
|
||||||
|
if not self.send_buffer:
|
||||||
|
self.send_buffer = chain
|
||||||
|
else:
|
||||||
|
self.send_buffer.chain.extend(chain.chain)
|
||||||
|
|
||||||
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
|
# 真流式传输
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
|
if time_since_last_edit >= throttle_interval:
|
||||||
|
ret = await self._post_send(stream=stream_payload)
|
||||||
|
stream_payload["index"] += 1
|
||||||
|
stream_payload["id"] = ret["id"]
|
||||||
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
if isinstance(source, botpy.message.C2CMessage):
|
||||||
|
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||||
|
stream_payload["state"] = 10
|
||||||
|
ret = await self._post_send(stream=stream_payload)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||||
|
self.send_buffer = None
|
||||||
|
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
async def _post_send(self, stream: dict = None):
|
||||||
|
if not self.send_buffer:
|
||||||
|
return
|
||||||
|
|
||||||
source = self.message_obj.raw_message
|
source = self.message_obj.raw_message
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
source,
|
source,
|
||||||
@@ -57,6 +97,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
"msg_id": self.message_obj.message_id,
|
"msg_id": self.message_obj.message_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
|
||||||
|
payload["msg_seq"] = random.randint(1, 10000)
|
||||||
|
|
||||||
match type(source):
|
match type(source):
|
||||||
case botpy.message.GroupMessage:
|
case botpy.message.GroupMessage:
|
||||||
if image_base64:
|
if image_base64:
|
||||||
@@ -65,7 +108,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
await self.bot.api.post_group_message(
|
ret = await self.bot.api.post_group_message(
|
||||||
group_openid=source.group_openid, **payload
|
group_openid=source.group_openid, **payload
|
||||||
)
|
)
|
||||||
case botpy.message.C2CMessage:
|
case botpy.message.C2CMessage:
|
||||||
@@ -75,22 +118,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
await self.bot.api.post_c2c_message(
|
if stream:
|
||||||
openid=source.author.user_openid, **payload
|
ret = await self.post_c2c_message(
|
||||||
)
|
openid=source.author.user_openid,
|
||||||
|
**payload,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ret = await self.post_c2c_message(
|
||||||
|
openid=source.author.user_openid, **payload
|
||||||
|
)
|
||||||
|
logger.debug(f"Message sent to C2C: {ret}")
|
||||||
case botpy.message.Message:
|
case botpy.message.Message:
|
||||||
if image_path:
|
if image_path:
|
||||||
payload["file_image"] = image_path
|
payload["file_image"] = image_path
|
||||||
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
|
ret = await self.bot.api.post_message(
|
||||||
|
channel_id=source.channel_id, **payload
|
||||||
|
)
|
||||||
case botpy.message.DirectMessage:
|
case botpy.message.DirectMessage:
|
||||||
if image_path:
|
if image_path:
|
||||||
payload["file_image"] = image_path
|
payload["file_image"] = image_path
|
||||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||||
|
|
||||||
await super().send(self.send_buffer)
|
await super().send(self.send_buffer)
|
||||||
|
|
||||||
self.send_buffer = None
|
self.send_buffer = None
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
async def upload_group_and_c2c_image(
|
async def upload_group_and_c2c_image(
|
||||||
self, image_base64: str, file_type: int, **kwargs
|
self, image_base64: str, file_type: int, **kwargs
|
||||||
) -> botpy.types.message.Media:
|
) -> botpy.types.message.Media:
|
||||||
@@ -112,6 +167,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
|
async def post_c2c_message(
|
||||||
|
self,
|
||||||
|
openid: str,
|
||||||
|
msg_type: int = 0,
|
||||||
|
content: str = None,
|
||||||
|
embed: message.Embed = None,
|
||||||
|
ark: message.Ark = None,
|
||||||
|
message_reference: message.Reference = None,
|
||||||
|
media: message.Media = None,
|
||||||
|
msg_id: str = None,
|
||||||
|
msg_seq: str = 1,
|
||||||
|
event_id: str = None,
|
||||||
|
markdown: message.MarkdownPayload = None,
|
||||||
|
keyboard: message.Keyboard = None,
|
||||||
|
stream: dict = None,
|
||||||
|
) -> message.Message:
|
||||||
|
payload = locals()
|
||||||
|
payload.pop("self", None)
|
||||||
|
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||||
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _parse_to_qqofficial(message: MessageChain):
|
async def _parse_to_qqofficial(message: MessageChain):
|
||||||
plain_text = ""
|
plain_text = ""
|
||||||
@@ -122,16 +198,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
plain_text += i.text
|
plain_text += i.text
|
||||||
elif isinstance(i, Image) and not image_base64:
|
elif isinstance(i, Image) and not image_base64:
|
||||||
if i.file and i.file.startswith("file:///"):
|
if i.file and i.file.startswith("file:///"):
|
||||||
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
|
image_base64 = file_to_base64(i.file[8:])
|
||||||
image_file_path = i.file[8:]
|
image_file_path = i.file[8:]
|
||||||
elif i.file and i.file.startswith("http"):
|
elif i.file and i.file.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(i.file)
|
image_file_path = await download_image_by_url(i.file)
|
||||||
image_base64 = file_to_base64(image_file_path).replace(
|
image_base64 = file_to_base64(image_file_path)
|
||||||
"base64://", ""
|
elif i.file and i.file.startswith("base64://"):
|
||||||
)
|
image_base64 = i.file
|
||||||
else:
|
else:
|
||||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
image_base64 = file_to_base64(i.file)
|
||||||
image_file_path = i.file
|
image_base64 = image_base64.removeprefix("base64://")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"qq_official 忽略 {i.type}")
|
logger.debug(f"qq_official 忽略 {i.type}")
|
||||||
return plain_text, image_base64, image_file_path
|
return plain_text, image_base64, image_file_path
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from astrbot.api.platform import (
|
|||||||
MessageType,
|
MessageType,
|
||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
)
|
)
|
||||||
|
from astrbot import logger
|
||||||
from astrbot.api.event import MessageChain
|
from astrbot.api.event import MessageChain
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
from astrbot.api.message_components import Image, Plain, At
|
from astrbot.api.message_components import Image, Plain, At
|
||||||
@@ -125,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"qq_official",
|
name="qq_official",
|
||||||
"QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -204,3 +206,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> botClient:
|
def get_client(self) -> botClient:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
await self.client.close()
|
||||||
|
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .qo_webhook_event import QQOfficialWebhookMessageEvent
|
|||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
from .qo_webhook_server import QQOfficialWebhook
|
from .qo_webhook_server import QQOfficialWebhook
|
||||||
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
# remove logger handler
|
# remove logger handler
|
||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
@@ -98,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"qq_official_webhook",
|
name="qq_official_webhook",
|
||||||
"QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
|
id=self.config.get("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
@@ -111,3 +113,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> botClient:
|
def get_client(self) -> botClient:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.webhook_helper.shutdown_event.set()
|
||||||
|
await self.client.close()
|
||||||
|
try:
|
||||||
|
await self.webhook_helper.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class QQOfficialWebhook:
|
|||||||
self.appid = config["appid"]
|
self.appid = config["appid"]
|
||||||
self.secret = config["secret"]
|
self.secret = config["secret"]
|
||||||
self.port = config.get("port", 6196)
|
self.port = config.get("port", 6196)
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
|
|
||||||
if isinstance(self.port, str):
|
if isinstance(self.port, str):
|
||||||
self.port = int(self.port)
|
self.port = int(self.port)
|
||||||
@@ -29,6 +30,7 @@ class QQOfficialWebhook:
|
|||||||
)
|
)
|
||||||
self.client = botpy_client
|
self.client = botpy_client
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
logger.info("正在登录到 QQ 官方机器人...")
|
logger.info("正在登录到 QQ 官方机器人...")
|
||||||
@@ -95,13 +97,14 @@ class QQOfficialWebhook:
|
|||||||
return {"opcode": 12}
|
return {"opcode": 12}
|
||||||
|
|
||||||
async def start_polling(self):
|
async def start_polling(self):
|
||||||
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。"
|
||||||
|
)
|
||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host="0.0.0.0",
|
host=self.callback_server_host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder,
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
while not self.event_queue.closed: # noqa: ASYNC110
|
await self.shutdown_event.wait()
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("qq_official_webhook 适配器已关闭。")
|
|
||||||
|
|||||||
@@ -1,33 +1,32 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from telegram import BotCommand, Update
|
||||||
|
from telegram.constants import ChatType
|
||||||
|
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
|
||||||
|
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||||
|
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
Platform,
|
|
||||||
AstrBotMessage,
|
AstrBotMessage,
|
||||||
MessageMember,
|
MessageMember,
|
||||||
PlatformMetadata,
|
|
||||||
MessageType,
|
MessageType,
|
||||||
)
|
Platform,
|
||||||
from astrbot.api.event import MessageChain
|
PlatformMetadata,
|
||||||
from astrbot.api.message_components import (
|
register_platform_adapter,
|
||||||
Plain,
|
|
||||||
Image,
|
|
||||||
Record,
|
|
||||||
File as AstrBotFile,
|
|
||||||
Video,
|
|
||||||
At,
|
|
||||||
)
|
)
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from astrbot.api.platform import register_platform_adapter
|
from astrbot.core.star.filter.command import CommandFilter
|
||||||
|
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import star_handlers_registry
|
||||||
|
|
||||||
from telegram import Update
|
|
||||||
from telegram.ext import ApplicationBuilder, ContextTypes, filters
|
|
||||||
from telegram.constants import ChatType
|
|
||||||
from telegram.ext import MessageHandler as TelegramMessageHandler
|
|
||||||
from .tg_event import TelegramPlatformEvent
|
from .tg_event import TelegramPlatformEvent
|
||||||
from astrbot.api import logger
|
|
||||||
from telegram.ext import ExtBot
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -59,6 +58,14 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
|
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
|
||||||
|
self.enable_command_register = self.config.get(
|
||||||
|
"telegram_command_register", True
|
||||||
|
)
|
||||||
|
self.enable_command_refresh = self.config.get(
|
||||||
|
"telegram_command_auto_refresh", True
|
||||||
|
)
|
||||||
|
self.last_command_hash = None
|
||||||
|
|
||||||
self.application = (
|
self.application = (
|
||||||
ApplicationBuilder()
|
ApplicationBuilder()
|
||||||
.token(self.config["telegram_token"])
|
.token(self.config["telegram_token"])
|
||||||
@@ -68,12 +75,14 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
message_handler = TelegramMessageHandler(
|
message_handler = TelegramMessageHandler(
|
||||||
filters=filters.ALL, # receive all messages
|
filters=filters.ALL, # receive all messages
|
||||||
callback=self.convert_message,
|
callback=self.message_handler,
|
||||||
)
|
)
|
||||||
self.application.add_handler(message_handler)
|
self.application.add_handler(message_handler)
|
||||||
self.client = self.application.bot
|
self.client = self.application.bot
|
||||||
logger.debug(f"Telegram base url: {self.client.base_url}")
|
logger.debug(f"Telegram base url: {self.client.base_url}")
|
||||||
|
|
||||||
|
self.scheduler = AsyncIOScheduler()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
self, session: MessageSesion, message_chain: MessageChain
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
@@ -87,94 +96,250 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
@override
|
@override
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"telegram",
|
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||||
"telegram 适配器",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await self.application.initialize()
|
await self.application.initialize()
|
||||||
await self.application.start()
|
await self.application.start()
|
||||||
|
|
||||||
|
if self.enable_command_register:
|
||||||
|
await self.register_commands()
|
||||||
|
|
||||||
|
if self.enable_command_refresh and self.enable_command_register:
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self.register_commands,
|
||||||
|
"interval",
|
||||||
|
seconds=self.config.get("telegram_command_register_interval", 300),
|
||||||
|
id="telegram_command_register",
|
||||||
|
misfire_grace_time=60,
|
||||||
|
)
|
||||||
|
self.scheduler.start()
|
||||||
|
|
||||||
queue = self.application.updater.start_polling()
|
queue = self.application.updater.start_polling()
|
||||||
logger.info("Telegram Platform Adapter is running.")
|
logger.info("Telegram Platform Adapter is running.")
|
||||||
await queue
|
await queue
|
||||||
|
|
||||||
|
async def register_commands(self):
|
||||||
|
"""收集所有注册的指令并注册到 Telegram"""
|
||||||
|
try:
|
||||||
|
commands = self.collect_commands()
|
||||||
|
|
||||||
|
if commands:
|
||||||
|
current_hash = hash(
|
||||||
|
tuple((cmd.command, cmd.description) for cmd in commands)
|
||||||
|
)
|
||||||
|
if current_hash == self.last_command_hash:
|
||||||
|
return
|
||||||
|
self.last_command_hash = current_hash
|
||||||
|
await self.client.delete_my_commands()
|
||||||
|
await self.client.set_my_commands(commands)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
|
||||||
|
|
||||||
|
def collect_commands(self) -> list[BotCommand]:
|
||||||
|
"""从注册的处理器中收集所有指令"""
|
||||||
|
command_dict = {}
|
||||||
|
skip_commands = {"start"}
|
||||||
|
|
||||||
|
for handler_md in star_handlers_registry:
|
||||||
|
handler_metadata = handler_md
|
||||||
|
if not star_map[handler_metadata.handler_module_path].activated:
|
||||||
|
continue
|
||||||
|
for event_filter in handler_metadata.event_filters:
|
||||||
|
cmd_info = self._extract_command_info(
|
||||||
|
event_filter, handler_metadata, skip_commands
|
||||||
|
)
|
||||||
|
if cmd_info:
|
||||||
|
cmd_name, description = cmd_info
|
||||||
|
command_dict.setdefault(cmd_name, description)
|
||||||
|
|
||||||
|
commands_a = sorted(command_dict.keys())
|
||||||
|
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_command_info(
|
||||||
|
event_filter, handler_metadata, skip_commands: set
|
||||||
|
) -> tuple[str, str] | None:
|
||||||
|
"""从事件过滤器中提取指令信息"""
|
||||||
|
cmd_name = None
|
||||||
|
is_group = False
|
||||||
|
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
|
||||||
|
if (
|
||||||
|
event_filter.parent_command_names
|
||||||
|
and event_filter.parent_command_names != [""]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
cmd_name = event_filter.command_name
|
||||||
|
elif isinstance(event_filter, CommandGroupFilter):
|
||||||
|
if event_filter.parent_group:
|
||||||
|
return None
|
||||||
|
cmd_name = event_filter.group_name
|
||||||
|
is_group = True
|
||||||
|
|
||||||
|
if not cmd_name or cmd_name in skip_commands:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
||||||
|
logger.debug(f"跳过无法注册的命令: {cmd_name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build description.
|
||||||
|
description = handler_metadata.desc or (
|
||||||
|
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
|
||||||
|
)
|
||||||
|
if len(description) > 30:
|
||||||
|
description = description[:30] + "..."
|
||||||
|
return cmd_name, description
|
||||||
|
|
||||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
await context.bot.send_message(
|
await context.bot.send_message(
|
||||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
logger.debug(f"Telegram message: {update.message}")
|
||||||
|
abm = await self.convert_message(update, context)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
async def convert_message(
|
async def convert_message(
|
||||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE
|
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||||
) -> AstrBotMessage:
|
) -> AstrBotMessage:
|
||||||
|
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||||
|
|
||||||
|
@param update: Telegram 的 Update 对象。
|
||||||
|
@param context: Telegram 的 Context 对象。
|
||||||
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
|
"""
|
||||||
message = AstrBotMessage()
|
message = AstrBotMessage()
|
||||||
|
message.session_id = str(update.message.chat.id)
|
||||||
# 获得是群聊还是私聊
|
# 获得是群聊还是私聊
|
||||||
if update.effective_chat.type == ChatType.PRIVATE:
|
if update.message.chat.type == ChatType.PRIVATE:
|
||||||
message.type = MessageType.FRIEND_MESSAGE
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
else:
|
else:
|
||||||
message.type = MessageType.GROUP_MESSAGE
|
message.type = MessageType.GROUP_MESSAGE
|
||||||
message.group_id = update.effective_chat.id
|
message.group_id = str(update.message.chat.id)
|
||||||
|
if update.message.message_thread_id:
|
||||||
|
# Topic Group
|
||||||
|
message.group_id += "#" + str(update.message.message_thread_id)
|
||||||
|
message.session_id = message.group_id
|
||||||
|
|
||||||
message.message_id = str(update.message.message_id)
|
message.message_id = str(update.message.message_id)
|
||||||
message.session_id = str(update.effective_chat.id)
|
|
||||||
message.sender = MessageMember(
|
message.sender = MessageMember(
|
||||||
str(update.effective_user.id), update.effective_user.username
|
str(update.message.from_user.id), update.message.from_user.username
|
||||||
)
|
)
|
||||||
message.self_id = str(context.bot.username)
|
message.self_id = str(context.bot.username)
|
||||||
message.raw_message = update
|
message.raw_message = update
|
||||||
message.message_str = ""
|
message.message_str = ""
|
||||||
message.message = []
|
message.message = []
|
||||||
|
|
||||||
logger.debug(f"Telegram message: {update.message}")
|
if update.message.reply_to_message and not (
|
||||||
|
update.message.is_topic_message
|
||||||
|
and update.message.message_thread_id
|
||||||
|
== update.message.reply_to_message.message_id
|
||||||
|
):
|
||||||
|
# 获取回复消息
|
||||||
|
reply_update = Update(
|
||||||
|
update_id=1,
|
||||||
|
message=update.message.reply_to_message,
|
||||||
|
)
|
||||||
|
reply_abm = await self.convert_message(reply_update, context, False)
|
||||||
|
|
||||||
|
message.message.append(
|
||||||
|
Comp.Reply(
|
||||||
|
id=reply_abm.message_id,
|
||||||
|
chain=reply_abm.message,
|
||||||
|
sender_id=reply_abm.sender.user_id,
|
||||||
|
sender_nickname=reply_abm.sender.nickname,
|
||||||
|
time=reply_abm.timestamp,
|
||||||
|
message_str=reply_abm.message_str,
|
||||||
|
text=reply_abm.message_str,
|
||||||
|
qq=reply_abm.sender.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if update.message.text:
|
if update.message.text:
|
||||||
|
# 处理文本消息
|
||||||
plain_text = update.message.text
|
plain_text = update.message.text
|
||||||
|
|
||||||
|
# 群聊场景命令特殊处理
|
||||||
|
if plain_text.startswith("/"):
|
||||||
|
command_parts = plain_text.split(" ", 1)
|
||||||
|
if "@" in command_parts[0]:
|
||||||
|
command, bot_name = command_parts[0].split("@")
|
||||||
|
if bot_name == self.client.username:
|
||||||
|
plain_text = command + (
|
||||||
|
f" {command_parts[1]}" if len(command_parts) > 1 else ""
|
||||||
|
)
|
||||||
|
|
||||||
if update.message.entities:
|
if update.message.entities:
|
||||||
for entity in update.message.entities:
|
for entity in update.message.entities:
|
||||||
if entity.type == "mention":
|
if entity.type == "mention":
|
||||||
name = plain_text[
|
name = plain_text[
|
||||||
entity.offset + 1 : entity.offset + entity.length
|
entity.offset + 1 : entity.offset + entity.length
|
||||||
]
|
]
|
||||||
message.message.append(At(qq=name, name=name))
|
message.message.append(Comp.At(qq=name, name=name))
|
||||||
plain_text = (
|
# 如果mention是当前bot则移除;否则保留
|
||||||
plain_text[: entity.offset]
|
if name.lower() == context.bot.username.lower():
|
||||||
+ plain_text[entity.offset + entity.length :]
|
plain_text = (
|
||||||
)
|
plain_text[: entity.offset]
|
||||||
|
+ plain_text[entity.offset + entity.length :]
|
||||||
|
)
|
||||||
|
|
||||||
if plain_text:
|
if plain_text:
|
||||||
message.message.append(Plain(plain_text))
|
message.message.append(Comp.Plain(plain_text))
|
||||||
message.message_str = plain_text
|
message.message_str = plain_text
|
||||||
|
|
||||||
if message.message_str == "/start":
|
if message.message_str.strip() == "/start":
|
||||||
await self.start(update, context)
|
await self.start(update, context)
|
||||||
return
|
return
|
||||||
|
|
||||||
elif update.message.voice:
|
elif update.message.voice:
|
||||||
file = await update.message.voice.get_file()
|
file = await update.message.voice.get_file()
|
||||||
message.message = [
|
message.message = [
|
||||||
Record(file=file.file_path, url=file.file_path),
|
Comp.Record(file=file.file_path, url=file.file_path),
|
||||||
]
|
]
|
||||||
|
|
||||||
elif update.message.photo:
|
elif update.message.photo:
|
||||||
photo = update.message.photo[-1] # get the largest photo
|
photo = update.message.photo[-1] # get the largest photo
|
||||||
file = await photo.get_file()
|
file = await photo.get_file()
|
||||||
message.message.append(Image(file=file.file_path, url=file.file_path))
|
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||||
|
if update.message.caption:
|
||||||
|
message.message_str = update.message.caption
|
||||||
|
message.message.append(Comp.Plain(message.message_str))
|
||||||
|
if update.message.caption_entities:
|
||||||
|
for entity in update.message.caption_entities:
|
||||||
|
if entity.type == "mention":
|
||||||
|
name = message.message_str[
|
||||||
|
entity.offset + 1 : entity.offset + entity.length
|
||||||
|
]
|
||||||
|
message.message.append(Comp.At(qq=name, name=name))
|
||||||
|
|
||||||
|
elif update.message.sticker:
|
||||||
|
# 将sticker当作图片处理
|
||||||
|
file = await update.message.sticker.get_file()
|
||||||
|
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||||
|
if update.message.sticker.emoji:
|
||||||
|
sticker_text = f"Sticker: {update.message.sticker.emoji}"
|
||||||
|
message.message_str = sticker_text
|
||||||
|
message.message.append(Comp.Plain(sticker_text))
|
||||||
|
|
||||||
elif update.message.document:
|
elif update.message.document:
|
||||||
file = await update.message.document.get_file()
|
file = await update.message.document.get_file()
|
||||||
message.message = [
|
message.message = [
|
||||||
AstrBotFile(
|
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||||
file=file.file_path, name=update.message.document.file_name
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
elif update.message.video:
|
elif update.message.video:
|
||||||
file = await update.message.video.get_file()
|
file = await update.message.video.get_file()
|
||||||
message.message = [
|
message.message = [
|
||||||
Video(file=file.file_path, path=file.file_path),
|
Comp.Video(file=file.file_path, path=file.file_path),
|
||||||
]
|
]
|
||||||
|
|
||||||
await self.handle_msg(message)
|
return message
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
message_event = TelegramPlatformEvent(
|
message_event = TelegramPlatformEvent(
|
||||||
@@ -188,3 +353,21 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> ExtBot:
|
def get_client(self) -> ExtBot:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
try:
|
||||||
|
if self.scheduler.running:
|
||||||
|
self.scheduler.shutdown()
|
||||||
|
|
||||||
|
await self.application.stop()
|
||||||
|
|
||||||
|
if self.enable_command_register:
|
||||||
|
await self.client.delete_my_commands()
|
||||||
|
|
||||||
|
# 保险起见先判断是否存在updater对象
|
||||||
|
if self.application.updater is not None:
|
||||||
|
await self.application.updater.stop()
|
||||||
|
|
||||||
|
logger.info("Telegram 适配器已被优雅地关闭")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Telegram 适配器关闭时出错: {e}")
|
||||||
|
|||||||
@@ -1,10 +1,34 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import asyncio
|
||||||
|
import telegramify_markdown
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||||
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
Reply,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Record,
|
||||||
|
)
|
||||||
from telegram.ext import ExtBot
|
from telegram.ext import ExtBot
|
||||||
|
from astrbot.core.utils.io import download_file
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
|
||||||
class TelegramPlatformEvent(AstrMessageEvent):
|
class TelegramPlatformEvent(AstrMessageEvent):
|
||||||
|
# Telegram 的最大消息长度限制
|
||||||
|
MAX_MESSAGE_LENGTH = 4096
|
||||||
|
|
||||||
|
SPLIT_PATTERNS = {
|
||||||
|
"paragraph": re.compile(r"\n\n"),
|
||||||
|
"line": re.compile(r"\n"),
|
||||||
|
"sentence": re.compile(r"[.!?。!?]"),
|
||||||
|
"word": re.compile(r"\s"),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_str: str,
|
message_str: str,
|
||||||
@@ -16,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
@staticmethod
|
def _split_message(self, text: str) -> list[str]:
|
||||||
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
|
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
while text:
|
||||||
|
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||||
|
chunks.append(text)
|
||||||
|
break
|
||||||
|
|
||||||
|
split_point = self.MAX_MESSAGE_LENGTH
|
||||||
|
segment = text[: self.MAX_MESSAGE_LENGTH]
|
||||||
|
|
||||||
|
for _, pattern in self.SPLIT_PATTERNS.items():
|
||||||
|
if matches := list(pattern.finditer(segment)):
|
||||||
|
last_match = matches[-1]
|
||||||
|
split_point = last_match.end()
|
||||||
|
break
|
||||||
|
|
||||||
|
chunks.append(text[:split_point])
|
||||||
|
text = text[split_point:].lstrip()
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
async def send_with_client(
|
||||||
|
self, client: ExtBot, message: MessageChain, user_name: str
|
||||||
|
):
|
||||||
image_path = None
|
image_path = None
|
||||||
|
|
||||||
has_reply = False
|
has_reply = False
|
||||||
@@ -31,36 +80,51 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
at_user_id = i.name
|
at_user_id = i.name
|
||||||
|
|
||||||
at_flag = False
|
at_flag = False
|
||||||
|
message_thread_id = None
|
||||||
|
if "#" in user_name:
|
||||||
|
# it's a supergroup chat with message_thread_id
|
||||||
|
user_name, message_thread_id = user_name.split("#")
|
||||||
for i in message.chain:
|
for i in message.chain:
|
||||||
payload = {
|
payload = {
|
||||||
"chat_id": user_name,
|
"chat_id": user_name,
|
||||||
}
|
}
|
||||||
if has_reply:
|
if has_reply:
|
||||||
payload["reply_to_message_id"] = reply_message_id
|
payload["reply_to_message_id"] = reply_message_id
|
||||||
|
if message_thread_id:
|
||||||
|
payload["message_thread_id"] = message_thread_id
|
||||||
|
|
||||||
if isinstance(i, Plain):
|
if isinstance(i, Plain):
|
||||||
if at_user_id and not at_flag:
|
if at_user_id and not at_flag:
|
||||||
i.text = f"@{at_user_id} " + i.text
|
i.text = f"@{at_user_id} {i.text}"
|
||||||
at_flag = True
|
at_flag = True
|
||||||
await client.send_message(text=i.text, **payload)
|
chunks = self._split_message(i.text)
|
||||||
|
for chunk in chunks:
|
||||||
|
try:
|
||||||
|
md_text = telegramify_markdown.markdownify(
|
||||||
|
chunk, max_line_length=None, normalize_whitespace=False
|
||||||
|
)
|
||||||
|
await client.send_message(
|
||||||
|
text=md_text, parse_mode="MarkdownV2", **payload
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"MarkdownV2 send failed: {e}. Using plain text instead."
|
||||||
|
)
|
||||||
|
await client.send_message(text=chunk, **payload)
|
||||||
elif isinstance(i, Image):
|
elif isinstance(i, Image):
|
||||||
if i.path:
|
image_path = await i.convert_to_file_path()
|
||||||
image_path = i.path
|
await client.send_photo(photo=image_path, **payload)
|
||||||
else:
|
|
||||||
image_path = i.file
|
|
||||||
|
|
||||||
if image_path.startswith("base64://"):
|
|
||||||
import base64
|
|
||||||
|
|
||||||
base64_data = image_path[9:]
|
|
||||||
image_bytes = base64.b64decode(base64_data)
|
|
||||||
await client.send_photo(photo=image_bytes, **payload)
|
|
||||||
else:
|
|
||||||
await client.send_photo(photo=image_path, **payload)
|
|
||||||
elif isinstance(i, File):
|
elif isinstance(i, File):
|
||||||
|
if i.file.startswith("https://"):
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
path = os.path.join(temp_dir, i.name)
|
||||||
|
await download_file(i.file, path)
|
||||||
|
i.file = path
|
||||||
|
|
||||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
await client.send_voice(voice=i.file, **payload)
|
path = await i.convert_to_file_path()
|
||||||
|
await client.send_voice(voice=path, **payload)
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
@@ -68,3 +132,110 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
message_thread_id = None
|
||||||
|
|
||||||
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
|
user_name = self.message_obj.group_id
|
||||||
|
else:
|
||||||
|
user_name = self.get_sender_id()
|
||||||
|
|
||||||
|
if "#" in user_name:
|
||||||
|
# it's a supergroup chat with message_thread_id
|
||||||
|
user_name, message_thread_id = user_name.split("#")
|
||||||
|
payload = {
|
||||||
|
"chat_id": user_name,
|
||||||
|
}
|
||||||
|
if message_thread_id:
|
||||||
|
payload["reply_to_message_id"] = message_thread_id
|
||||||
|
|
||||||
|
delta = ""
|
||||||
|
current_content = ""
|
||||||
|
message_id = None
|
||||||
|
last_edit_time = 0 # 上次编辑消息的时间
|
||||||
|
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
# 处理消息链中的每个组件
|
||||||
|
for i in chain.chain:
|
||||||
|
if isinstance(i, Plain):
|
||||||
|
delta += i.text
|
||||||
|
elif isinstance(i, Image):
|
||||||
|
image_path = await i.convert_to_file_path()
|
||||||
|
await self.client.send_photo(photo=image_path, **payload)
|
||||||
|
continue
|
||||||
|
elif isinstance(i, File):
|
||||||
|
if i.file.startswith("https://"):
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
path = os.path.join(temp_dir, i.name)
|
||||||
|
await download_file(i.file, path)
|
||||||
|
i.file = path
|
||||||
|
|
||||||
|
await self.client.send_document(
|
||||||
|
document=i.file, filename=i.name, **payload
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif isinstance(i, Record):
|
||||||
|
path = await i.convert_to_file_path()
|
||||||
|
await self.client.send_voice(voice=path, **payload)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Plain
|
||||||
|
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
|
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
|
||||||
|
if time_since_last_edit >= throttle_interval:
|
||||||
|
# 编辑消息
|
||||||
|
try:
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
current_content = delta
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
|
last_edit_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
) # 更新上次编辑的时间
|
||||||
|
else:
|
||||||
|
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||||
|
try:
|
||||||
|
msg = await self.client.send_message(text=delta, **payload)
|
||||||
|
current_content = delta
|
||||||
|
delta = ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
|
message_id = msg.message_id
|
||||||
|
last_edit_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
) # 记录初始消息发送时间
|
||||||
|
|
||||||
|
try:
|
||||||
|
if delta and current_content != delta:
|
||||||
|
try:
|
||||||
|
markdown_text = telegramify_markdown.markdownify(
|
||||||
|
delta, max_line_length=None, normalize_whitespace=False
|
||||||
|
)
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=markdown_text,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
parse_mode="MarkdownV2",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||||
|
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue
|
|||||||
from .webchat_event import WebChatMessageEvent
|
from .webchat_event import WebChatMessageEvent
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from ...register import register_platform_adapter
|
from ...register import register_platform_adapter
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
|
||||||
class QueueListener:
|
class QueueListener:
|
||||||
@@ -40,11 +41,11 @@ class WebChatAdapter(Platform):
|
|||||||
self.config = platform_config
|
self.config = platform_config
|
||||||
self.settings = platform_settings
|
self.settings = platform_settings
|
||||||
self.unique_session = platform_settings["unique_session"]
|
self.unique_session = platform_settings["unique_session"]
|
||||||
self.imgs_dir = "data/webchat/imgs"
|
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||||
|
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
"webchat",
|
name="webchat", description="webchat", id=self.config.get("id")
|
||||||
"webchat",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_by_session(
|
async def send_by_session(
|
||||||
@@ -119,3 +120,7 @@ class WebChatAdapter(Platform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.commit_event(message_event)
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
# Do nothing
|
||||||
|
pass
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ import uuid
|
|||||||
import base64
|
import base64
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import Plain, Image
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core import web_chat_back_queue
|
from astrbot.core import web_chat_back_queue
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
imgs_dir = "data/webchat/imgs"
|
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||||
|
|
||||||
|
|
||||||
class WebChatMessageEvent(AstrMessageEvent):
|
class WebChatMessageEvent(AstrMessageEvent):
|
||||||
@@ -16,16 +17,26 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
os.makedirs(imgs_dir, exist_ok=True)
|
os.makedirs(imgs_dir, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _send(message: MessageChain, session_id: str):
|
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||||
if not message:
|
if not message:
|
||||||
web_chat_back_queue.put_nowait(None)
|
await web_chat_back_queue.put(
|
||||||
return
|
{"type": "end", "data": "", "streaming": False}
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
cid = session_id.split("!")[-1]
|
cid = session_id.split("!")[-1]
|
||||||
|
data = ""
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
data = comp.text
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "plain",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
# save image to local
|
# save image to local
|
||||||
filename = str(uuid.uuid4()) + ".jpg"
|
filename = str(uuid.uuid4()) + ".jpg"
|
||||||
@@ -46,11 +57,69 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
with open(comp.file, "rb") as f2:
|
with open(comp.file, "rb") as f2:
|
||||||
f.write(f2.read())
|
f.write(f2.read())
|
||||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
data = f"[IMAGE]{filename}"
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Record):
|
||||||
|
# save record to local
|
||||||
|
filename = str(uuid.uuid4()) + ".wav"
|
||||||
|
path = os.path.join(imgs_dir, filename)
|
||||||
|
if comp.file and comp.file.startswith("file:///"):
|
||||||
|
ph = comp.file[8:]
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
with open(ph, "rb") as f2:
|
||||||
|
f.write(f2.read())
|
||||||
|
elif comp.file and comp.file.startswith("http"):
|
||||||
|
await download_image_by_url(comp.file, path=path)
|
||||||
|
else:
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
with open(comp.file, "rb") as f2:
|
||||||
|
f.write(f2.read())
|
||||||
|
data = f"[RECORD]{filename}"
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "record",
|
||||||
|
"cid": cid,
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"webchat 忽略: {comp.type}")
|
logger.debug(f"webchat 忽略: {comp.type}")
|
||||||
web_chat_back_queue.put_nowait(None)
|
|
||||||
|
return data
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": "",
|
||||||
|
"streaming": False,
|
||||||
|
"cid": self.session_id.split("!")[-1],
|
||||||
|
}
|
||||||
|
)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
final_data = ""
|
||||||
|
async for chain in generator:
|
||||||
|
final_data += await WebChatMessageEvent._send(
|
||||||
|
chain, session_id=self.session_id, streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": final_data,
|
||||||
|
"streaming": True,
|
||||||
|
"cid": self.session_id.split("!")[-1],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -0,0 +1,707 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.api.message_components import Plain, Image
|
||||||
|
from astrbot.api.platform import Platform, PlatformMetadata
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.platform.astrbot_message import (
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
)
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from .wechatpadpro_message_event import WeChatPadProMessageEvent
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||||
|
class WeChatPadProAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self._shutdown_event = None
|
||||||
|
self.wxnewpass = None
|
||||||
|
self.config = platform_config
|
||||||
|
self.settings = platform_settings
|
||||||
|
self.unique_session = platform_settings.get("unique_session", False)
|
||||||
|
|
||||||
|
self.metadata = PlatformMetadata(
|
||||||
|
name="wechatpadpro",
|
||||||
|
description="WeChatPadPro 消息平台适配器",
|
||||||
|
id=self.config.get("id", "wechatpadpro"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存配置信息
|
||||||
|
self.admin_key = self.config.get("admin_key")
|
||||||
|
self.host = self.config.get("host")
|
||||||
|
self.port = self.config.get("port")
|
||||||
|
self.active_mesasge_poll: bool = self.config.get(
|
||||||
|
"wpp_active_message_poll", False
|
||||||
|
)
|
||||||
|
self.active_message_poll_interval: int = self.config.get(
|
||||||
|
"wpp_active_message_poll_interval", 5
|
||||||
|
)
|
||||||
|
self.base_url = f"http://{self.host}:{self.port}"
|
||||||
|
self.auth_key = None # 用于保存生成的授权码
|
||||||
|
self.wxid = None # 用于保存登录成功后的 wxid
|
||||||
|
self.credentials_file = os.path.join(
|
||||||
|
get_astrbot_data_path(), "wechatpadpro_credentials.json"
|
||||||
|
) # 持久化文件路径
|
||||||
|
self.ws_handle_task = None
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
"""
|
||||||
|
启动平台适配器的运行实例。
|
||||||
|
"""
|
||||||
|
logger.info("WeChatPadPro 适配器正在启动...")
|
||||||
|
|
||||||
|
if loaded_credentials := self.load_credentials():
|
||||||
|
self.auth_key = loaded_credentials.get("auth_key")
|
||||||
|
self.wxid = loaded_credentials.get("wxid")
|
||||||
|
|
||||||
|
isLoginIn = await self.check_online_status()
|
||||||
|
|
||||||
|
# 检查在线状态
|
||||||
|
if self.auth_key and isLoginIn:
|
||||||
|
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
|
||||||
|
# 如果在线,连接 WebSocket 接收消息
|
||||||
|
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||||
|
else:
|
||||||
|
# 1. 生成授权码
|
||||||
|
if not self.auth_key:
|
||||||
|
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
|
||||||
|
await self.generate_auth_key()
|
||||||
|
|
||||||
|
# 2. 获取登录二维码
|
||||||
|
if not isLoginIn:
|
||||||
|
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
|
||||||
|
qr_code_url = await self.get_login_qr_code()
|
||||||
|
|
||||||
|
if qr_code_url:
|
||||||
|
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
|
||||||
|
else:
|
||||||
|
logger.error("无法获取登录二维码。")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. 检测扫码状态
|
||||||
|
login_successful = await self.check_login_status()
|
||||||
|
|
||||||
|
if login_successful:
|
||||||
|
logger.info("登录成功,WeChatPadPro适配器已连接。")
|
||||||
|
else:
|
||||||
|
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
|
||||||
|
await self.terminate()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 登录成功后,连接 WebSocket 接收消息
|
||||||
|
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||||
|
|
||||||
|
self._shutdown_event = asyncio.Event()
|
||||||
|
await self._shutdown_event.wait()
|
||||||
|
logger.info("WeChatPadPro 适配器已停止。")
|
||||||
|
|
||||||
|
def load_credentials(self):
|
||||||
|
"""
|
||||||
|
从文件中加载 auth_key 和 wxid。
|
||||||
|
"""
|
||||||
|
if os.path.exists(self.credentials_file):
|
||||||
|
try:
|
||||||
|
with open(self.credentials_file, "r") as f:
|
||||||
|
credentials = json.load(f)
|
||||||
|
logger.info("成功加载 WeChatPadPro 凭据。")
|
||||||
|
return credentials
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载 WeChatPadPro 凭据失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save_credentials(self):
|
||||||
|
"""
|
||||||
|
将 auth_key 和 wxid 保存到文件。
|
||||||
|
"""
|
||||||
|
credentials = {
|
||||||
|
"auth_key": self.auth_key,
|
||||||
|
"wxid": self.wxid,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
# 确保数据目录存在
|
||||||
|
data_dir = os.path.dirname(self.credentials_file)
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
with open(self.credentials_file, "w") as f:
|
||||||
|
json.dump(credentials, f)
|
||||||
|
logger.info("成功保存 WeChatPadPro 凭据。")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
|
||||||
|
|
||||||
|
async def check_online_status(self):
|
||||||
|
"""
|
||||||
|
检查 WeChatPadPro 设备是否在线。
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/login/GetLoginStatus"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.get(url, params=params) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
# 根据提供的在线接口返回示例,成功状态码是 200,loginState 为 1 表示在线
|
||||||
|
if response.status == 200 and response_data.get("Code") == 200:
|
||||||
|
login_state = response_data.get("Data", {}).get("loginState")
|
||||||
|
if login_state == 1:
|
||||||
|
logger.info("WeChatPadPro 设备当前在线。")
|
||||||
|
return True
|
||||||
|
# login_state == 3 为离线状态
|
||||||
|
elif login_state == 3:
|
||||||
|
logger.info(
|
||||||
|
"WeChatPadPro 设备不在线。"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"未知的在线状态: {login_state:}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
# Code == 300 为微信退出状态。
|
||||||
|
elif response.status == 200 and response_data.get("Code") == 300:
|
||||||
|
logger.info(
|
||||||
|
"WeChatPadPro 设备已退出。"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"检查在线状态失败: {response.status}, {response_data}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查在线状态时发生错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def generate_auth_key(self):
|
||||||
|
"""
|
||||||
|
生成授权码。
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/admin/GenAuthKey1"
|
||||||
|
params = {"key": self.admin_key}
|
||||||
|
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
# 修正成功判断条件和授权码提取路径
|
||||||
|
if response.status == 200 and response_data.get("Code") == 200:
|
||||||
|
# 授权码在 Data 字段的列表中
|
||||||
|
if (
|
||||||
|
response_data.get("Data")
|
||||||
|
and isinstance(response_data["Data"], list)
|
||||||
|
and len(response_data["Data"]) > 0
|
||||||
|
):
|
||||||
|
self.auth_key = response_data["Data"][0]
|
||||||
|
logger.info("成功获取授权码")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"生成授权码成功但未找到授权码: {response_data}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"生成授权码失败: {response.status}, {response_data}"
|
||||||
|
)
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成授权码时发生错误: {e}")
|
||||||
|
|
||||||
|
async def get_login_qr_code(self):
|
||||||
|
"""
|
||||||
|
获取登录二维码地址。
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/login/GetLoginQrCodeNew"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
payload = {} # 根据文档,这个接口的 body 可以为空
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
# 修正成功判断条件和数据提取路径
|
||||||
|
if response.status == 200 and response_data.get("Code") == 200:
|
||||||
|
# 二维码地址在 Data.QrCodeUrl 字段中
|
||||||
|
if response_data.get("Data") and response_data["Data"].get(
|
||||||
|
"QrCodeUrl"
|
||||||
|
):
|
||||||
|
return response_data["Data"]["QrCodeUrl"]
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"获取登录二维码成功但未找到二维码地址: {response_data}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"获取登录二维码失败: {response.status}, {response_data}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取登录二维码时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def check_login_status(self):
|
||||||
|
"""
|
||||||
|
循环检测扫码状态。
|
||||||
|
尝试 6 次后跳出循环,添加倒计时。
|
||||||
|
返回 True 如果登录成功,否则返回 False。
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/login/CheckLoginStatus"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
|
||||||
|
attempts = 0 # 初始化尝试次数
|
||||||
|
max_attempts = 36 # 最大尝试次数
|
||||||
|
countdown = 180 # 倒计时时长
|
||||||
|
logger.info(f"请在 {countdown} 秒内扫码登录。")
|
||||||
|
while attempts < max_attempts:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.get(url, params=params) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
# 成功判断条件和数据提取路径
|
||||||
|
if response.status == 200 and response_data.get("Code") == 200:
|
||||||
|
if (
|
||||||
|
response_data.get("Data")
|
||||||
|
and response_data["Data"].get("state") is not None
|
||||||
|
):
|
||||||
|
status = response_data["Data"]["state"]
|
||||||
|
logger.info(
|
||||||
|
f"第 {attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}秒"
|
||||||
|
)
|
||||||
|
if status == 2: # 状态 2 表示登录成功
|
||||||
|
self.wxid = response_data["Data"].get("wxid")
|
||||||
|
self.wxnewpass = response_data["Data"].get(
|
||||||
|
"wxnewpass"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}"
|
||||||
|
)
|
||||||
|
self.save_credentials() # 登录成功后保存凭据
|
||||||
|
return True
|
||||||
|
elif status == -2: # 二维码过期
|
||||||
|
logger.error("二维码已过期,请重新获取。")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"检测登录状态成功但未找到登录状态: {response_data}"
|
||||||
|
)
|
||||||
|
elif response_data.get("Code") == 300:
|
||||||
|
# "不存在状态"
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"检测登录状态失败: {response.status}, {response_data}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
attempts += 1
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检测登录状态时发生错误: {e}")
|
||||||
|
attempts += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
attempts += 1
|
||||||
|
await asyncio.sleep(5) # 每隔5秒检测一次
|
||||||
|
logger.warning("登录检测超过最大尝试次数,退出检测。")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def connect_websocket(self):
|
||||||
|
"""
|
||||||
|
建立 WebSocket 连接并处理接收到的消息。
|
||||||
|
"""
|
||||||
|
os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}"
|
||||||
|
ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}"
|
||||||
|
logger.info(
|
||||||
|
f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***"
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
async with websockets.connect(ws_url) as websocket:
|
||||||
|
logger.info("WebSocket 连接成功。")
|
||||||
|
# 设置空闲超时重连
|
||||||
|
wait_time = (
|
||||||
|
self.active_message_poll_interval
|
||||||
|
if self.active_mesasge_poll
|
||||||
|
else 120
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
message = await asyncio.wait_for(
|
||||||
|
websocket.recv(), timeout=wait_time
|
||||||
|
)
|
||||||
|
# logger.debug(message) # 不显示原始消息内容
|
||||||
|
asyncio.create_task(self.handle_websocket_message(message))
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"WebSocket 连接空闲超过 {wait_time} s")
|
||||||
|
break
|
||||||
|
except websockets.exceptions.ConnectionClosedOK:
|
||||||
|
logger.info("WebSocket 连接正常关闭。")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
async def handle_websocket_message(self, message: str):
|
||||||
|
"""
|
||||||
|
处理从 WebSocket 接收到的消息。
|
||||||
|
"""
|
||||||
|
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||||
|
try:
|
||||||
|
message_data = json.loads(message)
|
||||||
|
if (
|
||||||
|
message_data.get("msg_id") is not None
|
||||||
|
and message_data.get("from_user_name") is not None
|
||||||
|
):
|
||||||
|
abm = await self.convert_message(message_data)
|
||||||
|
if abm:
|
||||||
|
# 创建 WeChatPadProMessageEvent 实例
|
||||||
|
message_event = WeChatPadProMessageEvent(
|
||||||
|
message_str=abm.message_str,
|
||||||
|
message_obj=abm,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=abm.session_id,
|
||||||
|
# 传递适配器实例,以便在事件中调用 send 方法
|
||||||
|
adapter=self,
|
||||||
|
)
|
||||||
|
# 提交事件到事件队列
|
||||||
|
self.commit_event(message_event)
|
||||||
|
else:
|
||||||
|
logger.warning(f"收到未知结构的 WebSocket 消息: {message_data}")
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"无法解析 WebSocket 消息为 JSON: {message}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||||
|
|
||||||
|
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||||
|
"""
|
||||||
|
将 WeChatPadPro 原始消息转换为 AstrBotMessage。
|
||||||
|
"""
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.raw_message = raw_message
|
||||||
|
abm.message_id = str(raw_message.get("msg_id"))
|
||||||
|
abm.timestamp = raw_message.get("create_time")
|
||||||
|
abm.self_id = self.wxid
|
||||||
|
|
||||||
|
if int(time.time()) - abm.timestamp > 180:
|
||||||
|
logger.warning(
|
||||||
|
f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}。"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||||
|
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||||
|
content = raw_message.get("content", {}).get("str", "")
|
||||||
|
push_content = raw_message.get("push_content", "")
|
||||||
|
msg_type = raw_message.get("msg_type")
|
||||||
|
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.message = []
|
||||||
|
|
||||||
|
# 如果是机器人自己发送的消息、回显消息或系统消息,忽略
|
||||||
|
if from_user_name == self.wxid:
|
||||||
|
logger.info("忽略来自自己的消息。")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if from_user_name in ["weixin", "newsapp", "newsapp_wechat"]:
|
||||||
|
logger.info("忽略来自微信团队的消息。")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 先判断群聊/私聊并设置基本属性
|
||||||
|
if await self._process_chat_type(
|
||||||
|
abm, raw_message, from_user_name, to_user_name, content, push_content
|
||||||
|
):
|
||||||
|
# 再根据消息类型处理消息内容
|
||||||
|
await self._process_message_content(abm, raw_message, msg_type, content)
|
||||||
|
|
||||||
|
return abm
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _process_chat_type(
|
||||||
|
self,
|
||||||
|
abm: AstrBotMessage,
|
||||||
|
raw_message: dict,
|
||||||
|
from_user_name: str,
|
||||||
|
to_user_name: str,
|
||||||
|
content: str,
|
||||||
|
push_content: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。
|
||||||
|
"""
|
||||||
|
if from_user_name == "weixin":
|
||||||
|
return False
|
||||||
|
if "@chatroom" in from_user_name:
|
||||||
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
|
abm.group_id = from_user_name
|
||||||
|
|
||||||
|
parts = content.split(":\n", 1)
|
||||||
|
sender_wxid = parts[0] if len(parts) == 2 else ""
|
||||||
|
abm.sender = MessageMember(user_id=sender_wxid, nickname="")
|
||||||
|
|
||||||
|
# 获取群聊发送者的nickname
|
||||||
|
if sender_wxid:
|
||||||
|
accurate_nickname = await self._get_group_member_nickname(
|
||||||
|
abm.group_id, sender_wxid
|
||||||
|
)
|
||||||
|
if accurate_nickname:
|
||||||
|
abm.sender.nickname = accurate_nickname
|
||||||
|
|
||||||
|
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
|
||||||
|
if self.unique_session:
|
||||||
|
abm.session_id = f"{from_user_name}_{to_user_name}"
|
||||||
|
else:
|
||||||
|
abm.session_id = from_user_name
|
||||||
|
else:
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.group_id = ""
|
||||||
|
nick_name = ""
|
||||||
|
if push_content and " : " in push_content:
|
||||||
|
nick_name = push_content.split(" : ")[0]
|
||||||
|
abm.sender = MessageMember(user_id=from_user_name, nickname=nick_name)
|
||||||
|
abm.session_id = from_user_name
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _get_group_member_nickname(
|
||||||
|
self, group_id: str, member_wxid: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
通过接口获取群成员的昵称。
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/group/GetChatroomMemberDetail"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
payload = {
|
||||||
|
"ChatRoomName": group_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
if response.status == 200 and response_data.get("Code") == 200:
|
||||||
|
# 从返回数据中查找对应成员的昵称
|
||||||
|
member_list = (
|
||||||
|
response_data.get("Data", {})
|
||||||
|
.get("member_data", {})
|
||||||
|
.get("chatroom_member_list", [])
|
||||||
|
)
|
||||||
|
for member in member_list:
|
||||||
|
if member.get("user_name") == member_wxid:
|
||||||
|
return member.get("nick_name")
|
||||||
|
logger.warning(
|
||||||
|
f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"获取群成员详情失败: {response.status}, {response_data}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取群成员详情时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _download_raw_image(
|
||||||
|
self, from_user_name: str, to_user_name: str, msg_id: int
|
||||||
|
):
|
||||||
|
"""下载原始图片。"""
|
||||||
|
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
payload = {
|
||||||
|
"CompressType": 0,
|
||||||
|
"FromUserName": from_user_name,
|
||||||
|
"MsgId": msg_id,
|
||||||
|
"Section": {"DataLen": 61440, "StartPos": 0},
|
||||||
|
"ToUserName": to_user_name,
|
||||||
|
"TotalLen": 0,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
else:
|
||||||
|
logger.error(f"下载图片失败: {response.status}")
|
||||||
|
return None
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载图片时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _process_message_content(
|
||||||
|
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。
|
||||||
|
"""
|
||||||
|
if msg_type == 1: # 文本消息
|
||||||
|
abm.message_str = content
|
||||||
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
|
parts = content.split(":\n", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
abm.message_str = parts[1]
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
else:
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
else: # 私聊消息
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
elif msg_type == 3:
|
||||||
|
# 图片消息
|
||||||
|
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||||
|
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||||
|
msg_id = raw_message.get("msg_id")
|
||||||
|
image_resp = await self._download_raw_image(
|
||||||
|
from_user_name, to_user_name, msg_id
|
||||||
|
)
|
||||||
|
image_bs64_data = (
|
||||||
|
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||||
|
)
|
||||||
|
if image_bs64_data:
|
||||||
|
abm.message.append(Image.fromBase64(image_bs64_data))
|
||||||
|
elif msg_type == 47:
|
||||||
|
# 视频消息 (注意:表情消息也是 47,需要区分)
|
||||||
|
logger.warning("收到视频消息,待实现。")
|
||||||
|
elif msg_type == 50:
|
||||||
|
# 语音/视频
|
||||||
|
logger.warning("收到语音/视频消息,待实现。")
|
||||||
|
elif msg_type == 49:
|
||||||
|
# 引用消息
|
||||||
|
logger.warning("收到引用消息,待实现。")
|
||||||
|
else:
|
||||||
|
logger.warning(f"收到未处理的消息类型: {msg_type}。")
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""
|
||||||
|
终止一个平台的运行实例。
|
||||||
|
"""
|
||||||
|
logger.info("终止 WeChatPadPro 适配器。")
|
||||||
|
try:
|
||||||
|
if self.ws_handle_task:
|
||||||
|
self.ws_handle_task.cancel()
|
||||||
|
self._shutdown_event.set()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
"""
|
||||||
|
得到一个平台的元数据。
|
||||||
|
"""
|
||||||
|
return self.metadata
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
dummy_message_obj = AstrBotMessage()
|
||||||
|
dummy_message_obj.session_id = session.session_id
|
||||||
|
# 根据 session_id 判断消息类型
|
||||||
|
if "@chatroom" in session.session_id:
|
||||||
|
dummy_message_obj.type = MessageType.GROUP_MESSAGE
|
||||||
|
dummy_message_obj.group_id = session.session_id
|
||||||
|
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||||
|
else:
|
||||||
|
dummy_message_obj.type = MessageType.FRIEND_MESSAGE
|
||||||
|
dummy_message_obj.group_id = ""
|
||||||
|
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||||
|
sending_event = WeChatPadProMessageEvent(
|
||||||
|
message_str="",
|
||||||
|
message_obj=dummy_message_obj,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=session.session_id,
|
||||||
|
adapter=self,
|
||||||
|
)
|
||||||
|
# 调用实例方法 send
|
||||||
|
await sending_event.send(message_chain)
|
||||||
|
|
||||||
|
async def get_contact_list(self):
|
||||||
|
"""
|
||||||
|
获取联系人列表。
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/friend/GetContactList"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
logger.error(f"获取联系人列表失败: {response.status}")
|
||||||
|
return None
|
||||||
|
result = await response.json()
|
||||||
|
if result.get("Code") == 200 and result.get("Data"):
|
||||||
|
contact_list = (
|
||||||
|
result.get("Data", {})
|
||||||
|
.get("ContactList", {})
|
||||||
|
.get("contactUsernameList", [])
|
||||||
|
)
|
||||||
|
return contact_list
|
||||||
|
else:
|
||||||
|
logger.error(f"获取联系人列表失败: {result}")
|
||||||
|
return None
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取联系人列表时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_contact_details_list(
|
||||||
|
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
获取联系人详情列表。
|
||||||
|
"""
|
||||||
|
if room_wx_id_list is None:
|
||||||
|
room_wx_id_list = []
|
||||||
|
if user_names is None:
|
||||||
|
user_names = []
|
||||||
|
url = f"{self.base_url}/friend/GetContactDetailsList"
|
||||||
|
params = {"key": self.auth_key}
|
||||||
|
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
logger.error(f"获取联系人详情列表失败: {response.status}")
|
||||||
|
return None
|
||||||
|
result = await response.json()
|
||||||
|
if result.get("Code") == 200 and result.get("Data"):
|
||||||
|
contact_list = result.get("Data", {}).get("contactList", {})
|
||||||
|
return contact_list
|
||||||
|
else:
|
||||||
|
logger.error(f"获取联系人详情列表失败: {result}")
|
||||||
|
return None
|
||||||
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取联系人详情列表时发生错误: {e}")
|
||||||
|
return None
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from PIL import Image as PILImage # 使用别名避免冲突
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.message.components import Image, Plain # Import Image
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
|
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||||
|
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
adapter: "WeChatPadProAdapter", # 传递适配器实例
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.message_obj = message_obj # Save the full message object
|
||||||
|
self.adapter = adapter # Save the adapter instance
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
for comp in message.chain:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
await self._send_text(session, comp.text)
|
||||||
|
elif isinstance(comp, Image):
|
||||||
|
await self._send_image(session, comp)
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||||
|
b64 = await comp.convert_to_base64()
|
||||||
|
raw = self._validate_base64(b64)
|
||||||
|
b64c = self._compress_image(raw)
|
||||||
|
payload = {
|
||||||
|
"MsgItem": [
|
||||||
|
{"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
url = f"{self.adapter.base_url}/message/SendImageNewMessage"
|
||||||
|
await self._post(session, url, payload)
|
||||||
|
|
||||||
|
async def _send_text(self, session: aiohttp.ClientSession, text: str):
|
||||||
|
if (
|
||||||
|
self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息
|
||||||
|
and self.adapter.settings.get(
|
||||||
|
"reply_with_mention", False
|
||||||
|
) # 检查适配器设置是否启用 reply_with_mention
|
||||||
|
and self.message_obj.sender # 确保有发送者信息
|
||||||
|
and (
|
||||||
|
self.message_obj.sender.user_id or self.message_obj.sender.nickname
|
||||||
|
) # 确保发送者有 ID 或昵称
|
||||||
|
):
|
||||||
|
# 优先使用 nickname,如果没有则使用 user_id
|
||||||
|
mention_text = (
|
||||||
|
self.message_obj.sender.nickname or self.message_obj.sender.user_id
|
||||||
|
)
|
||||||
|
message_text = f"@{mention_text} {text}"
|
||||||
|
# logger.info(f"已添加 @ 信息: {message_text}")
|
||||||
|
else:
|
||||||
|
message_text = text
|
||||||
|
payload = {
|
||||||
|
"MsgItem": [
|
||||||
|
{"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
url = f"{self.adapter.base_url}/message/SendTextMessage"
|
||||||
|
await self._post(session, url, payload)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_base64(b64: str) -> bytes:
|
||||||
|
return base64.b64decode(b64, validate=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compress_image(data: bytes) -> str:
|
||||||
|
img = PILImage.open(io.BytesIO(data))
|
||||||
|
buf = io.BytesIO()
|
||||||
|
if img.format == "JPEG":
|
||||||
|
img.save(buf, "JPEG", quality=80)
|
||||||
|
else:
|
||||||
|
if img.mode in ("RGBA", "P"):
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.save(buf, "JPEG", quality=80)
|
||||||
|
# logger.info("图片处理完成!!!")
|
||||||
|
return base64.b64encode(buf.getvalue()).decode()
|
||||||
|
|
||||||
|
async def _post(self, session, url, payload):
|
||||||
|
params = {"key": self.adapter.auth_key}
|
||||||
|
try:
|
||||||
|
async with session.post(url, params=params, json=payload) as resp:
|
||||||
|
data = await resp.json()
|
||||||
|
if resp.status != 200 or data.get("Code") != 200:
|
||||||
|
logger.error(f"{url} failed: {resp.status} {data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{url} error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: 添加对其他消息组件类型的处理 (Record, Video, At等)
|
||||||
|
# elif isinstance(component, Record):
|
||||||
|
# pass
|
||||||
|
# elif isinstance(component, Video):
|
||||||
|
# pass
|
||||||
|
# elif isinstance(component, At):
|
||||||
|
# pass
|
||||||
|
# ...
|
||||||
@@ -1,28 +1,33 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
|
||||||
import quart
|
|
||||||
|
|
||||||
|
import quart
|
||||||
|
from requests import Response
|
||||||
|
from wechatpy.enterprise import WeChatClient, parse_message
|
||||||
|
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||||
|
from wechatpy.enterprise.messages import ImageMessage, TextMessage, VoiceMessage
|
||||||
|
from wechatpy.exceptions import InvalidSignatureException
|
||||||
|
from wechatpy.messages import BaseMessage
|
||||||
|
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Image, Plain, Record
|
||||||
from astrbot.api.platform import (
|
from astrbot.api.platform import (
|
||||||
Platform,
|
|
||||||
AstrBotMessage,
|
AstrBotMessage,
|
||||||
MessageMember,
|
MessageMember,
|
||||||
PlatformMetadata,
|
|
||||||
MessageType,
|
MessageType,
|
||||||
|
Platform,
|
||||||
|
PlatformMetadata,
|
||||||
|
register_platform_adapter,
|
||||||
)
|
)
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.api.message_components import Plain, Image, Record
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
|
||||||
from astrbot.api.platform import register_platform_adapter
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from requests import Response
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
|
||||||
from wechatpy.enterprise import WeChatClient
|
|
||||||
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
|
||||||
from wechatpy.exceptions import InvalidSignatureException
|
|
||||||
from wechatpy.enterprise import parse_message
|
|
||||||
from .wecom_event import WecomPlatformEvent
|
from .wecom_event import WecomPlatformEvent
|
||||||
|
from .wecom_kf import WeChatKF
|
||||||
|
from .wecom_kf_message import WeChatKFMessage
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -34,6 +39,7 @@ class WecomServer:
|
|||||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.port = int(config.get("port"))
|
self.port = int(config.get("port"))
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/callback/command", view_func=self.verify, methods=["GET"]
|
"/callback/command", view_func=self.verify, methods=["GET"]
|
||||||
)
|
)
|
||||||
@@ -49,6 +55,7 @@ class WecomServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.callback = None
|
self.callback = None
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def verify(self):
|
async def verify(self):
|
||||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||||
@@ -86,17 +93,17 @@ class WecomServer:
|
|||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
async def start_polling(self):
|
async def start_polling(self):
|
||||||
logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。")
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。"
|
||||||
|
)
|
||||||
await self.server.run_task(
|
await self.server.run_task(
|
||||||
host="0.0.0.0",
|
host=self.callback_server_host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
shutdown_trigger=self.shutdown_trigger_placeholder,
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder(self):
|
async def shutdown_trigger(self):
|
||||||
while not self.event_queue.closed: # noqa: ASYNC110
|
await self.shutdown_event.wait()
|
||||||
await asyncio.sleep(1)
|
|
||||||
logger.info("企业微信 适配器已关闭。")
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("wecom", "wecom 适配器")
|
@register_platform_adapter("wecom", "wecom 适配器")
|
||||||
@@ -129,9 +136,40 @@ class WecomPlatformAdapter(Platform):
|
|||||||
self.config["corpid"].strip(),
|
self.config["corpid"].strip(),
|
||||||
self.config["secret"].strip(),
|
self.config["secret"].strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 微信客服
|
||||||
|
self.kf_name = self.config.get("kf_name", None)
|
||||||
|
if self.kf_name:
|
||||||
|
# inject
|
||||||
|
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||||
|
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||||
|
self.client.kf = self.wechat_kf_api
|
||||||
|
self.client.kf_message = self.wechat_kf_message_api
|
||||||
|
|
||||||
self.client.API_BASE_URL = self.api_base_url
|
self.client.API_BASE_URL = self.api_base_url
|
||||||
|
|
||||||
async def callback(msg):
|
async def callback(msg: BaseMessage):
|
||||||
|
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||||
|
|
||||||
|
def get_latest_msg_item() -> dict | None:
|
||||||
|
token = msg._data["Token"]
|
||||||
|
kfid = msg._data["OpenKfId"]
|
||||||
|
has_more = 1
|
||||||
|
ret = {}
|
||||||
|
while has_more:
|
||||||
|
ret = self.wechat_kf_api.sync_msg(token, kfid)
|
||||||
|
has_more = ret["has_more"]
|
||||||
|
msg_list = ret.get("msg_list", [])
|
||||||
|
if msg_list:
|
||||||
|
return msg_list[-1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
msg_new = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, get_latest_msg_item
|
||||||
|
)
|
||||||
|
if msg_new:
|
||||||
|
await self.convert_wechat_kf_message(msg_new)
|
||||||
|
return
|
||||||
await self.convert_message(msg)
|
await self.convert_message(msg)
|
||||||
|
|
||||||
self.server.callback = callback
|
self.server.callback = callback
|
||||||
@@ -151,9 +189,39 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if self.kf_name:
|
||||||
|
try:
|
||||||
|
acc_list = (
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self.wechat_kf_api.get_account_list
|
||||||
|
)
|
||||||
|
).get("account_list", [])
|
||||||
|
logger.debug(f"获取到微信客服列表: {str(acc_list)}")
|
||||||
|
for acc in acc_list:
|
||||||
|
name = acc.get("name", None)
|
||||||
|
if name != self.kf_name:
|
||||||
|
continue
|
||||||
|
open_kfid = acc.get("open_kfid", None)
|
||||||
|
if not open_kfid:
|
||||||
|
logger.error("获取微信客服失败,open_kfid 为空。")
|
||||||
|
logger.debug(f"Found open_kfid: {str(open_kfid)}")
|
||||||
|
kf_url = (
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
self.wechat_kf_api.add_contact_way,
|
||||||
|
open_kfid,
|
||||||
|
"astrbot_placeholder",
|
||||||
|
)
|
||||||
|
).get("url", "")
|
||||||
|
logger.info(
|
||||||
|
f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
await self.server.start_polling()
|
await self.server.start_polling()
|
||||||
|
|
||||||
async def convert_message(self, msg):
|
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
if msg.type == "text":
|
if msg.type == "text":
|
||||||
assert isinstance(msg, TextMessage)
|
assert isinstance(msg, TextMessage)
|
||||||
@@ -189,14 +257,15 @@ class WecomPlatformAdapter(Platform):
|
|||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None, self.client.media.download, msg.media_id
|
None, self.client.media.download, msg.media_id
|
||||||
)
|
)
|
||||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
f.write(resp.content)
|
f.write(resp.content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
|
||||||
audio = AudioSegment.from_file(path)
|
audio = AudioSegment.from_file(path)
|
||||||
audio.export(path_wav, format="wav")
|
audio.export(path_wav, format="wav")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -216,10 +285,43 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.timestamp = msg.time
|
abm.timestamp = msg.time
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
|
else:
|
||||||
|
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"abm: {abm}")
|
logger.info(f"abm: {abm}")
|
||||||
await self.handle_msg(abm)
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||||
|
msgtype = msg.get("msgtype", None)
|
||||||
|
external_userid = msg.get("external_userid", None)
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.raw_message = msg
|
||||||
|
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||||
|
abm.self_id = msg["open_kfid"]
|
||||||
|
abm.sender = MessageMember(external_userid, external_userid)
|
||||||
|
abm.session_id = external_userid
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
|
||||||
|
if msgtype == "text":
|
||||||
|
text = msg.get("text", {}).get("content", "").strip()
|
||||||
|
abm.message = [Plain(text=text)]
|
||||||
|
abm.message_str = text
|
||||||
|
elif msgtype == "image":
|
||||||
|
media_id = msg.get("image", {}).get("media_id", "")
|
||||||
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.client.media.download, media_id
|
||||||
|
)
|
||||||
|
path = f"data/temp/wechat_kf_{media_id}.jpg"
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(resp.content)
|
||||||
|
abm.message = [Image(file=path, url=path)]
|
||||||
|
abm.message_str = "[图片]"
|
||||||
|
else:
|
||||||
|
logger.warning(f"未实现的微信客服消息事件: {msg}")
|
||||||
|
return
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
message_event = WecomPlatformEvent(
|
message_event = WecomPlatformEvent(
|
||||||
message_str=message.message_str,
|
message_str=message.message_str,
|
||||||
@@ -232,3 +334,11 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def get_client(self) -> WeChatClient:
|
def get_client(self) -> WeChatClient:
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.server.shutdown_event.set()
|
||||||
|
try:
|
||||||
|
await self.server.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("企业微信 适配器已被优雅地关闭")
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
import asyncio
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
from astrbot.api.message_components import Plain, Image, Record
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
from wechatpy.enterprise import WeChatClient
|
from wechatpy.enterprise import WeChatClient
|
||||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
from .wecom_kf_message import WeChatKFMessage
|
||||||
|
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pydub
|
import pydub
|
||||||
@@ -34,70 +37,158 @@ class WecomPlatformEvent(AstrMessageEvent):
|
|||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def split_plain(self, plain: str) -> list[str]:
|
||||||
|
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plain (str): 要分割的长文本
|
||||||
|
Returns:
|
||||||
|
list[str]: 分割后的文本列表
|
||||||
|
"""
|
||||||
|
if len(plain) <= 2048:
|
||||||
|
return [plain]
|
||||||
|
else:
|
||||||
|
result = []
|
||||||
|
start = 0
|
||||||
|
while start < len(plain):
|
||||||
|
# 剩下的字符串长度<2048时结束
|
||||||
|
if start + 2048 >= len(plain):
|
||||||
|
result.append(plain[start:])
|
||||||
|
break
|
||||||
|
|
||||||
|
# 向前搜索分割标点符号
|
||||||
|
end = min(start + 2048, len(plain))
|
||||||
|
cut_position = end
|
||||||
|
for i in range(end, start, -1):
|
||||||
|
if i < len(plain) and plain[i - 1] in [
|
||||||
|
"。",
|
||||||
|
"!",
|
||||||
|
"?",
|
||||||
|
".",
|
||||||
|
"!",
|
||||||
|
"?",
|
||||||
|
"\n",
|
||||||
|
";",
|
||||||
|
";",
|
||||||
|
]:
|
||||||
|
cut_position = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# 没找到合适的位置分割, 直接切分
|
||||||
|
if cut_position == end and end < len(plain):
|
||||||
|
cut_position = end
|
||||||
|
|
||||||
|
result.append(plain[start:cut_position])
|
||||||
|
start = cut_position
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
message_obj = self.message_obj
|
message_obj = self.message_obj
|
||||||
|
|
||||||
for comp in message.chain:
|
is_wechat_kf = hasattr(self.client, "kf_message")
|
||||||
if isinstance(comp, Plain):
|
if is_wechat_kf:
|
||||||
self.client.message.send_text(
|
# 微信客服
|
||||||
message_obj.self_id, message_obj.session_id, comp.text
|
kf_message_api = getattr(self.client, "kf_message", None)
|
||||||
)
|
if not kf_message_api:
|
||||||
elif isinstance(comp, Image):
|
logger.warning("未找到微信客服发送消息方法。")
|
||||||
img_url = comp.file
|
return
|
||||||
img_path = ""
|
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||||
if img_url.startswith("file:///"):
|
user_id = self.get_sender_id()
|
||||||
img_path = img_url[8:]
|
for comp in message.chain:
|
||||||
elif comp.file and comp.file.startswith("http"):
|
if isinstance(comp, Plain):
|
||||||
img_path = await download_image_by_url(comp.file)
|
# Split long text messages if needed
|
||||||
else:
|
plain_chunks = await self.split_plain(comp.text)
|
||||||
img_path = img_url
|
for chunk in plain_chunks:
|
||||||
|
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||||
|
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||||
|
elif isinstance(comp, Image):
|
||||||
|
img_path = await comp.convert_to_file_path()
|
||||||
|
|
||||||
with open(img_path, "rb") as f:
|
with open(img_path, "rb") as f:
|
||||||
try:
|
try:
|
||||||
response = self.client.media.upload("image", f)
|
response = self.client.media.upload("image", f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"企业微信上传图片失败: {e}")
|
logger.error(f"微信客服上传图片失败: {e}")
|
||||||
await self.send(
|
await self.send(
|
||||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
MessageChain().message(f"微信客服上传图片失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.debug(f"微信客服上传图片返回: {response}")
|
||||||
|
kf_message_api.send_image(
|
||||||
|
user_id,
|
||||||
|
self.get_self_id(),
|
||||||
|
response["media_id"],
|
||||||
)
|
)
|
||||||
return
|
|
||||||
logger.info(f"企业微信上传图片返回: {response}")
|
|
||||||
self.client.message.send_image(
|
|
||||||
message_obj.self_id,
|
|
||||||
message_obj.session_id,
|
|
||||||
response["media_id"],
|
|
||||||
)
|
|
||||||
elif isinstance(comp, Record):
|
|
||||||
record_url = comp.file
|
|
||||||
record_path = ""
|
|
||||||
|
|
||||||
if record_url.startswith("file:///"):
|
|
||||||
record_path = record_url[8:]
|
|
||||||
elif record_url.startswith("http"):
|
|
||||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
|
||||||
else:
|
else:
|
||||||
record_path = record_url
|
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||||
|
else:
|
||||||
# 转成amr
|
# 企业微信应用
|
||||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
for comp in message.chain:
|
||||||
pydub.AudioSegment.from_wav(record_path).export(
|
if isinstance(comp, Plain):
|
||||||
record_path_amr, format="amr"
|
# Split long text messages if needed
|
||||||
)
|
plain_chunks = await self.split_plain(comp.text)
|
||||||
|
for chunk in plain_chunks:
|
||||||
with open(record_path_amr, "rb") as f:
|
self.client.message.send_text(
|
||||||
try:
|
message_obj.self_id, message_obj.session_id, chunk
|
||||||
response = self.client.media.upload("voice", f)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"企业微信上传语音失败: {e}")
|
|
||||||
await self.send(
|
|
||||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
|
||||||
)
|
)
|
||||||
return
|
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||||
logger.info(f"企业微信上传语音返回: {response}")
|
elif isinstance(comp, Image):
|
||||||
self.client.message.send_voice(
|
img_path = await comp.convert_to_file_path()
|
||||||
message_obj.self_id,
|
|
||||||
message_obj.session_id,
|
with open(img_path, "rb") as f:
|
||||||
response["media_id"],
|
try:
|
||||||
|
response = self.client.media.upload("image", f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"企业微信上传图片失败: {e}")
|
||||||
|
await self.send(
|
||||||
|
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.debug(f"企业微信上传图片返回: {response}")
|
||||||
|
self.client.message.send_image(
|
||||||
|
message_obj.self_id,
|
||||||
|
message_obj.session_id,
|
||||||
|
response["media_id"],
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Record):
|
||||||
|
record_path = await comp.convert_to_file_path()
|
||||||
|
# 转成amr
|
||||||
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
|
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
|
||||||
|
pydub.AudioSegment.from_wav(record_path).export(
|
||||||
|
record_path_amr, format="amr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with open(record_path_amr, "rb") as f:
|
||||||
|
try:
|
||||||
|
response = self.client.media.upload("voice", f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"企业微信上传语音失败: {e}")
|
||||||
|
await self.send(
|
||||||
|
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(f"企业微信上传语音返回: {response}")
|
||||||
|
self.client.message.send_voice(
|
||||||
|
message_obj.self_id,
|
||||||
|
message_obj.session_id,
|
||||||
|
response["media_id"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
278
astrbot/core/platform/sources/wecom/wecom_kf.py
Normal file
278
astrbot/core/platform/sources/wecom/wecom_kf.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2014-2020 messense
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from wechatpy.client.api.base import BaseWeChatAPI
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatKF(BaseWeChatAPI):
|
||||||
|
"""
|
||||||
|
微信客服接口
|
||||||
|
|
||||||
|
https://work.weixin.qq.com/api/doc/90000/90135/94670
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sync_msg(self, token, open_kfid, cursor="", limit=1000):
|
||||||
|
"""
|
||||||
|
微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收)
|
||||||
|
、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。
|
||||||
|
支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。
|
||||||
|
|
||||||
|
|
||||||
|
:param token: 回调事件返回的token字段,10分钟内有效;可不填,如果不填接口有严格的频率限制。不多于128字节
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param cursor: 上一次调用时返回的next_cursor,第一次拉取可以不填。不多于64字节
|
||||||
|
:param limit: 期望请求的数据量,默认值和最大值都为1000。
|
||||||
|
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||||
|
return self._post("kf/sync_msg", data=data)
|
||||||
|
|
||||||
|
def get_service_state(self, open_kfid, external_userid):
|
||||||
|
"""
|
||||||
|
获取会话状态
|
||||||
|
|
||||||
|
ID 状态 说明
|
||||||
|
0 未处理 新会话接入。可选择:1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待
|
||||||
|
1 由智能助手接待 可使用API回复消息。可选择转入待接入池或者指定接待人员处理。
|
||||||
|
2 待接入池排队中 在待接入池中排队等待接待人员接入。可选择转为指定人员接待
|
||||||
|
3 由人工接待 人工接待中。可选择结束会话
|
||||||
|
4 已结束 会话已经结束。不允许变更会话状态,等待用户重新发起咨询
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param external_userid: 微信客户的external_userid
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
"external_userid": external_userid,
|
||||||
|
}
|
||||||
|
return self._post("kf/service_state/get", data=data)
|
||||||
|
|
||||||
|
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||||
|
"""
|
||||||
|
变更会话状态
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param external_userid: 微信客户的external_userid
|
||||||
|
:param service_state: 当前的会话状态,状态定义参考概述中的表格
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
"external_userid": external_userid,
|
||||||
|
"service_state": service_state,
|
||||||
|
}
|
||||||
|
if servicer_userid:
|
||||||
|
data["servicer_userid"] = servicer_userid
|
||||||
|
return self._post("kf/service_state/trans", data=data)
|
||||||
|
|
||||||
|
def get_servicer_list(self, open_kfid):
|
||||||
|
"""
|
||||||
|
获取接待人员列表
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
}
|
||||||
|
return self._get("kf/servicer/list", params=data)
|
||||||
|
|
||||||
|
def add_servicer(self, open_kfid, userid_list):
|
||||||
|
"""
|
||||||
|
添加接待人员
|
||||||
|
添加指定客服帐号的接待人员。
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param userid_list: 接待人员userid列表
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
if not isinstance(userid_list, list):
|
||||||
|
userid_list = [userid_list]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
"userid_list": userid_list,
|
||||||
|
}
|
||||||
|
return self._post("kf/servicer/add", data=data)
|
||||||
|
|
||||||
|
def del_servicer(self, open_kfid, userid_list):
|
||||||
|
"""
|
||||||
|
删除接待人员
|
||||||
|
从客服帐号删除接待人员
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param userid_list: 接待人员userid列表
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
if not isinstance(userid_list, list):
|
||||||
|
userid_list = [userid_list]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
"userid_list": userid_list,
|
||||||
|
}
|
||||||
|
return self._post("kf/servicer/del", data=data)
|
||||||
|
|
||||||
|
def batchget_customer(self, external_userid_list):
|
||||||
|
"""
|
||||||
|
客户基本信息获取
|
||||||
|
|
||||||
|
:param external_userid_list: external_userid列表
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
if not isinstance(external_userid_list, list):
|
||||||
|
external_userid_list = [external_userid_list]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"external_userid_list": external_userid_list,
|
||||||
|
}
|
||||||
|
return self._post("kf/customer/batchget", data=data)
|
||||||
|
|
||||||
|
def get_account_list(self):
|
||||||
|
"""
|
||||||
|
获取客服帐号列表
|
||||||
|
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
return self._get("kf/account/list")
|
||||||
|
|
||||||
|
def add_contact_way(self, open_kfid, scene):
|
||||||
|
"""
|
||||||
|
获取客服帐号链接
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param scene: 场景值,字符串类型,由开发者自定义。不多于32字节;字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {"open_kfid": open_kfid, "scene": scene}
|
||||||
|
return self._post("kf/add_contact_way", data=data)
|
||||||
|
|
||||||
|
def get_upgrade_service_config(self):
|
||||||
|
"""
|
||||||
|
获取配置的专员与客户群
|
||||||
|
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
return self._get("kf/customer/get_upgrade_service_config")
|
||||||
|
|
||||||
|
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||||
|
"""
|
||||||
|
为客户升级为专员或客户群服务
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param external_userid: 微信客户的external_userid
|
||||||
|
:param service_type: 表示是升级到专员服务还是客户群服务。1:专员服务。2:客户群服务
|
||||||
|
:param member: 推荐的服务专员,type等于1时有效
|
||||||
|
:param groupchat: 推荐的客户群,type等于2时有效
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
"external_userid": external_userid,
|
||||||
|
"type": service_type,
|
||||||
|
}
|
||||||
|
if service_type == 1:
|
||||||
|
data["member"] = member
|
||||||
|
else:
|
||||||
|
data["groupchat"] = groupchat
|
||||||
|
return self._post("kf/customer/upgrade_service", data=data)
|
||||||
|
|
||||||
|
def cancel_upgrade_service(self, open_kfid, external_userid):
|
||||||
|
"""
|
||||||
|
为客户取消推荐
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param external_userid: 微信客户的external_userid
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
data = {"open_kfid": open_kfid, "external_userid": external_userid}
|
||||||
|
return self._post("kf/customer/cancel_upgrade_service", data=data)
|
||||||
|
|
||||||
|
def send_msg_on_event(self, code, msgtype, msg_content, msgid=None):
|
||||||
|
"""
|
||||||
|
当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。
|
||||||
|
支持发送消息类型:文本、菜单消息。
|
||||||
|
|
||||||
|
:param code: 事件响应消息对应的code。通过事件回调下发,仅可使用一次。
|
||||||
|
:param msgtype: 消息类型。对不同的msgtype,有相应的结构描述,详见消息类型
|
||||||
|
:param msg_content: 目前支持文本与菜单消息,具体查看文档
|
||||||
|
:param msgid: 消息ID。如果请求参数指定了msgid,则原样返回,否则系统自动生成并返回。不多于32字节;
|
||||||
|
字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
data = {"code": code, "msgtype": msgtype}
|
||||||
|
if msgid:
|
||||||
|
data["msgid"] = msgid
|
||||||
|
data.update(msg_content)
|
||||||
|
return self._post("kf/send_msg_on_event", data=data)
|
||||||
|
|
||||||
|
def get_corp_statistic(self, start_time, end_time, open_kfid=None):
|
||||||
|
"""
|
||||||
|
获取「客户数据统计」企业汇总数据
|
||||||
|
|
||||||
|
:param start_time: 开始时间
|
||||||
|
:param end_time: 结束时间
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||||
|
return self._post("kf/get_corp_statistic", data=data)
|
||||||
|
|
||||||
|
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||||
|
"""
|
||||||
|
获取「客户数据统计」接待人员明细数据
|
||||||
|
|
||||||
|
:param start_time: 开始时间
|
||||||
|
:param end_time: 结束时间
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param servicer_userid: 接待人员
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
"servicer_userid": servicer_userid,
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time,
|
||||||
|
}
|
||||||
|
return self._post("kf/get_servicer_statistic", data=data)
|
||||||
|
|
||||||
|
def account_update(self, open_kfid, name, media_id):
|
||||||
|
"""
|
||||||
|
修改客服账号
|
||||||
|
|
||||||
|
:param open_kfid: 客服帐号ID
|
||||||
|
:param name: 客服名称
|
||||||
|
:param media_id: 客服头像临时素材
|
||||||
|
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
data = {"open_kfid": open_kfid, "name": name, "media_id": media_id}
|
||||||
|
return self._post("kf/account/update", data=data)
|
||||||
159
astrbot/core/platform/sources/wecom/wecom_kf_message.py
Normal file
159
astrbot/core/platform/sources/wecom/wecom_kf_message.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2014-2020 messense
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from optionaldict import optionaldict
|
||||||
|
|
||||||
|
from wechatpy.client.api.base import BaseWeChatAPI
|
||||||
|
|
||||||
|
class WeChatKFMessage(BaseWeChatAPI):
|
||||||
|
"""
|
||||||
|
发送微信客服消息
|
||||||
|
|
||||||
|
https://work.weixin.qq.com/api/doc/90000/90135/94677
|
||||||
|
|
||||||
|
支持:
|
||||||
|
* 文本消息
|
||||||
|
* 图片消息
|
||||||
|
* 语音消息
|
||||||
|
* 视频消息
|
||||||
|
* 文件消息
|
||||||
|
* 图文链接
|
||||||
|
* 小程序
|
||||||
|
* 菜单消息
|
||||||
|
* 地理位置
|
||||||
|
"""
|
||||||
|
|
||||||
|
def send(self, user_id, open_kfid, msgid="", msg=None):
|
||||||
|
"""
|
||||||
|
当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。
|
||||||
|
注意仅当微信客户在主动发送消息给客服后的48小时内,企业可发送消息给客户,最多可发送5条消息;若用户继续发送消息,企业可再次下发消息。
|
||||||
|
支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。
|
||||||
|
|
||||||
|
:param user_id: 指定接收消息的客户UserID
|
||||||
|
:param open_kfid: 指定发送消息的客服帐号ID
|
||||||
|
:param msgid: 指定消息ID
|
||||||
|
:param tag_ids: 标签ID列表。
|
||||||
|
:param msg: 发送消息的 dict 对象
|
||||||
|
:type msg: dict | None
|
||||||
|
:return: 接口调用结果
|
||||||
|
"""
|
||||||
|
msg = msg or {}
|
||||||
|
data = {
|
||||||
|
"touser": user_id,
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
}
|
||||||
|
if msgid:
|
||||||
|
data["msgid"] = msgid
|
||||||
|
data.update(msg)
|
||||||
|
return self._post("kf/send_msg", data=data)
|
||||||
|
|
||||||
|
def send_text(self, user_id, open_kfid, content, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={"msgtype": "text", "text": {"content": content}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_image(self, user_id, open_kfid, media_id, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={"msgtype": "image", "image": {"media_id": media_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_voice(self, user_id, open_kfid, media_id, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={"msgtype": "voice", "voice": {"media_id": media_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_video(self, user_id, open_kfid, media_id, msgid=""):
|
||||||
|
video_data = optionaldict()
|
||||||
|
video_data["media_id"] = media_id
|
||||||
|
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={"msgtype": "video", "video": dict(video_data)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_file(self, user_id, open_kfid, media_id, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={"msgtype": "file", "file": {"media_id": media_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_articles_link(self, user_id, open_kfid, article, msgid=""):
|
||||||
|
articles_data = {
|
||||||
|
"title": article["title"],
|
||||||
|
"desc": article["desc"],
|
||||||
|
"url": article["url"],
|
||||||
|
"thumb_media_id": article["thumb_media_id"],
|
||||||
|
}
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={
|
||||||
|
"msgtype": "msgmenu",
|
||||||
|
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={
|
||||||
|
"msgtype": "location",
|
||||||
|
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||||
|
return self.send(
|
||||||
|
user_id,
|
||||||
|
open_kfid,
|
||||||
|
msgid,
|
||||||
|
msg={
|
||||||
|
"msgtype": "miniprogram",
|
||||||
|
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -0,0 +1,286 @@
|
|||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
import quart
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
PlatformMetadata,
|
||||||
|
MessageType,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from astrbot.api.platform import register_platform_adapter
|
||||||
|
from astrbot.core import logger
|
||||||
|
from requests import Response
|
||||||
|
|
||||||
|
from wechatpy.utils import check_signature
|
||||||
|
from wechatpy.crypto import WeChatCrypto
|
||||||
|
from wechatpy import WeChatClient
|
||||||
|
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
|
||||||
|
from wechatpy.exceptions import InvalidSignatureException
|
||||||
|
from wechatpy import parse_message
|
||||||
|
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 12):
|
||||||
|
from typing import override
|
||||||
|
else:
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class WecomServer:
|
||||||
|
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||||
|
self.server = quart.Quart(__name__)
|
||||||
|
self.port = int(config.get("port"))
|
||||||
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
|
self.token = config.get("token")
|
||||||
|
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||||
|
self.appid = config.get("appid")
|
||||||
|
self.server.add_url_rule(
|
||||||
|
"/callback/command", view_func=self.verify, methods=["GET"]
|
||||||
|
)
|
||||||
|
self.server.add_url_rule(
|
||||||
|
"/callback/command", view_func=self.callback_command, methods=["POST"]
|
||||||
|
)
|
||||||
|
self.crypto = WeChatCrypto(self.token, self.encoding_aes_key, self.appid)
|
||||||
|
|
||||||
|
self.event_queue = event_queue
|
||||||
|
|
||||||
|
self.callback = None
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def verify(self):
|
||||||
|
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||||
|
|
||||||
|
args = quart.request.args
|
||||||
|
if not args.get("signature", None):
|
||||||
|
logger.error("未知的响应,请检查回调地址是否填写正确。")
|
||||||
|
return "err"
|
||||||
|
try:
|
||||||
|
check_signature(
|
||||||
|
self.token,
|
||||||
|
args.get("signature"),
|
||||||
|
args.get("timestamp"),
|
||||||
|
args.get("nonce"),
|
||||||
|
)
|
||||||
|
logger.info("验证请求有效性成功。")
|
||||||
|
return args.get("echostr", "empty")
|
||||||
|
except InvalidSignatureException:
|
||||||
|
logger.error("验证请求有效性失败,签名异常,请检查配置。")
|
||||||
|
return "err"
|
||||||
|
|
||||||
|
async def callback_command(self):
|
||||||
|
data = await quart.request.get_data()
|
||||||
|
msg_signature = quart.request.args.get("msg_signature")
|
||||||
|
timestamp = quart.request.args.get("timestamp")
|
||||||
|
nonce = quart.request.args.get("nonce")
|
||||||
|
try:
|
||||||
|
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||||
|
except InvalidSignatureException:
|
||||||
|
logger.error("解密失败,签名异常,请检查配置。")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
msg = parse_message(xml)
|
||||||
|
logger.info(f"解析成功: {msg}")
|
||||||
|
|
||||||
|
if self.callback:
|
||||||
|
result_xml = await self.callback(msg)
|
||||||
|
if not result_xml:
|
||||||
|
return "success"
|
||||||
|
if isinstance(result_xml, str):
|
||||||
|
return result_xml
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
async def start_polling(self):
|
||||||
|
logger.info(
|
||||||
|
f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。"
|
||||||
|
)
|
||||||
|
await self.server.run_task(
|
||||||
|
host=self.callback_server_host,
|
||||||
|
port=self.port,
|
||||||
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown_trigger(self):
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
|
||||||
|
class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self.config = platform_config
|
||||||
|
self.settingss = platform_settings
|
||||||
|
self.client_self_id = uuid.uuid4().hex[:8]
|
||||||
|
self.api_base_url = platform_config.get(
|
||||||
|
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
|
||||||
|
)
|
||||||
|
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||||
|
|
||||||
|
if not self.api_base_url:
|
||||||
|
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||||
|
|
||||||
|
if self.api_base_url.endswith("/"):
|
||||||
|
self.api_base_url = self.api_base_url[:-1]
|
||||||
|
if not self.api_base_url.endswith("/cgi-bin"):
|
||||||
|
self.api_base_url += "/cgi-bin"
|
||||||
|
|
||||||
|
if not self.api_base_url.endswith("/"):
|
||||||
|
self.api_base_url += "/"
|
||||||
|
|
||||||
|
self.server = WecomServer(self._event_queue, self.config)
|
||||||
|
|
||||||
|
self.client = WeChatClient(
|
||||||
|
self.config["appid"].strip(),
|
||||||
|
self.config["secret"].strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client.API_BASE_URL = self.api_base_url
|
||||||
|
|
||||||
|
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||||
|
# msgid -> Future
|
||||||
|
self.wexin_event_workers: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
async def callback(msg: BaseMessage):
|
||||||
|
try:
|
||||||
|
if self.active_send_mode:
|
||||||
|
await self.convert_message(msg, None)
|
||||||
|
else:
|
||||||
|
if msg.id in self.wexin_event_workers:
|
||||||
|
future = self.wexin_event_workers[msg.id]
|
||||||
|
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||||
|
else:
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
self.wexin_event_workers[msg.id] = future
|
||||||
|
await self.convert_message(msg, future)
|
||||||
|
# I love shield so much!
|
||||||
|
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||||
|
logger.debug(f"Got future result: {result}")
|
||||||
|
self.wexin_event_workers.pop(msg.id, None)
|
||||||
|
return result # xml. see weixin_offacc_event.py
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换消息时出现异常: {e}")
|
||||||
|
|
||||||
|
self.server.callback = callback
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
"weixin_official_account",
|
||||||
|
"微信公众平台 适配器",
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def run(self):
|
||||||
|
await self.server.start_polling()
|
||||||
|
|
||||||
|
async def convert_message(
|
||||||
|
self, msg, future: asyncio.Future = None
|
||||||
|
) -> AstrBotMessage | None:
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
if isinstance(msg, TextMessage):
|
||||||
|
abm.message_str = msg.content
|
||||||
|
abm.self_id = str(msg.target)
|
||||||
|
abm.message = [Plain(msg.content)]
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
msg.source,
|
||||||
|
msg.source,
|
||||||
|
)
|
||||||
|
abm.message_id = msg.id
|
||||||
|
abm.timestamp = msg.time
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
elif msg.type == "image":
|
||||||
|
assert isinstance(msg, ImageMessage)
|
||||||
|
abm.message_str = "[图片]"
|
||||||
|
abm.self_id = str(msg.target)
|
||||||
|
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
msg.source,
|
||||||
|
msg.source,
|
||||||
|
)
|
||||||
|
abm.message_id = msg.id
|
||||||
|
abm.timestamp = msg.time
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
elif msg.type == "voice":
|
||||||
|
assert isinstance(msg, VoiceMessage)
|
||||||
|
|
||||||
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.client.media.download, msg.media_id
|
||||||
|
)
|
||||||
|
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(resp.content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||||
|
audio = AudioSegment.from_file(path)
|
||||||
|
audio.export(path_wav, format="wav")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
|
||||||
|
)
|
||||||
|
path_wav = path
|
||||||
|
return
|
||||||
|
|
||||||
|
abm.message_str = ""
|
||||||
|
abm.self_id = str(msg.target)
|
||||||
|
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
msg.source,
|
||||||
|
msg.source,
|
||||||
|
)
|
||||||
|
abm.message_id = msg.id
|
||||||
|
abm.timestamp = msg.time
|
||||||
|
abm.session_id = abm.sender.user_id
|
||||||
|
else:
|
||||||
|
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||||
|
future.set_result(None)
|
||||||
|
return
|
||||||
|
# 很不优雅 :(
|
||||||
|
abm.raw_message = {
|
||||||
|
"message": msg,
|
||||||
|
"future": future,
|
||||||
|
"active_send_mode": self.active_send_mode,
|
||||||
|
}
|
||||||
|
logger.info(f"abm: {abm}")
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
message_event = WeixinOfficialAccountPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.client,
|
||||||
|
)
|
||||||
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
def get_client(self) -> WeChatClient:
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.server.shutdown_event.set()
|
||||||
|
try:
|
||||||
|
await self.server.server.shutdown()
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
logger.info("微信公众平台 适配器已被优雅地关闭")
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
|
from astrbot.api.message_components import Plain, Image, Record
|
||||||
|
from wechatpy import WeChatClient
|
||||||
|
from wechatpy.replies import TextReply, ImageReply, VoiceReply
|
||||||
|
|
||||||
|
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pydub
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
client: WeChatClient,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_with_client(
|
||||||
|
client: WeChatClient, message: MessageChain, user_name: str
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def split_plain(self, plain: str) -> list[str]:
|
||||||
|
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plain (str): 要分割的长文本
|
||||||
|
Returns:
|
||||||
|
list[str]: 分割后的文本列表
|
||||||
|
"""
|
||||||
|
if len(plain) <= 2048:
|
||||||
|
return [plain]
|
||||||
|
else:
|
||||||
|
result = []
|
||||||
|
start = 0
|
||||||
|
while start < len(plain):
|
||||||
|
# 剩下的字符串长度<2048时结束
|
||||||
|
if start + 2048 >= len(plain):
|
||||||
|
result.append(plain[start:])
|
||||||
|
break
|
||||||
|
|
||||||
|
# 向前搜索分割标点符号
|
||||||
|
end = min(start + 2048, len(plain))
|
||||||
|
cut_position = end
|
||||||
|
for i in range(end, start, -1):
|
||||||
|
if i < len(plain) and plain[i - 1] in [
|
||||||
|
"。",
|
||||||
|
"!",
|
||||||
|
"?",
|
||||||
|
".",
|
||||||
|
"!",
|
||||||
|
"?",
|
||||||
|
"\n",
|
||||||
|
";",
|
||||||
|
";",
|
||||||
|
]:
|
||||||
|
cut_position = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# 没找到合适的位置分割, 直接切分
|
||||||
|
if cut_position == end and end < len(plain):
|
||||||
|
cut_position = end
|
||||||
|
|
||||||
|
result.append(plain[start:cut_position])
|
||||||
|
start = cut_position
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
message_obj = self.message_obj
|
||||||
|
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||||
|
for comp in message.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
# Split long text messages if needed
|
||||||
|
plain_chunks = await self.split_plain(comp.text)
|
||||||
|
for chunk in plain_chunks:
|
||||||
|
if active_send_mode:
|
||||||
|
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||||
|
else:
|
||||||
|
reply = TextReply(
|
||||||
|
content=chunk,
|
||||||
|
message=self.message_obj.raw_message["message"],
|
||||||
|
)
|
||||||
|
xml = reply.render()
|
||||||
|
future = self.message_obj.raw_message["future"]
|
||||||
|
assert isinstance(future, asyncio.Future)
|
||||||
|
future.set_result(xml)
|
||||||
|
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||||
|
elif isinstance(comp, Image):
|
||||||
|
img_path = await comp.convert_to_file_path()
|
||||||
|
|
||||||
|
with open(img_path, "rb") as f:
|
||||||
|
try:
|
||||||
|
response = self.client.media.upload("image", f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"微信公众平台上传图片失败: {e}")
|
||||||
|
await self.send(
|
||||||
|
MessageChain().message(f"微信公众平台上传图片失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||||
|
|
||||||
|
if active_send_mode:
|
||||||
|
self.client.message.send_image(
|
||||||
|
message_obj.sender.user_id,
|
||||||
|
response["media_id"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reply = ImageReply(
|
||||||
|
media_id=response["media_id"],
|
||||||
|
message=self.message_obj.raw_message["message"],
|
||||||
|
)
|
||||||
|
xml = reply.render()
|
||||||
|
future = self.message_obj.raw_message["future"]
|
||||||
|
assert isinstance(future, asyncio.Future)
|
||||||
|
future.set_result(xml)
|
||||||
|
|
||||||
|
elif isinstance(comp, Record):
|
||||||
|
record_path = await comp.convert_to_file_path()
|
||||||
|
# 转成amr
|
||||||
|
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||||
|
pydub.AudioSegment.from_wav(record_path).export(
|
||||||
|
record_path_amr, format="amr"
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(record_path_amr, "rb") as f:
|
||||||
|
try:
|
||||||
|
response = self.client.media.upload("voice", f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"微信公众平台上传语音失败: {e}")
|
||||||
|
await self.send(
|
||||||
|
MessageChain().message(f"微信公众平台上传语音失败: {e}")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
if active_send_mode:
|
||||||
|
self.client.message.send_voice(
|
||||||
|
message_obj.sender.user_id,
|
||||||
|
response["media_id"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reply = VoiceReply(
|
||||||
|
media_id=response["media_id"],
|
||||||
|
message=self.message_obj.raw_message["message"],
|
||||||
|
)
|
||||||
|
xml = reply.render()
|
||||||
|
future = self.message_obj.raw_message["future"]
|
||||||
|
assert isinstance(future, asyncio.Future)
|
||||||
|
future.set_result(xml)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||||
|
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from .provider import Provider, Personality, STTProvider
|
from .provider import Provider, Personality, STTProvider
|
||||||
|
|
||||||
from .entites import ProviderMetaData
|
from .entities import ProviderMetaData
|
||||||
|
|
||||||
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
|
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
|
||||||
|
|||||||
@@ -1,67 +1,19 @@
|
|||||||
import enum
|
from astrbot.core.provider.entities import (
|
||||||
from dataclasses import dataclass, field
|
ProviderRequest,
|
||||||
from typing import List, Dict, Type
|
ProviderType,
|
||||||
from .func_tool_manager import FuncCall
|
ProviderMetaData,
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
ToolCallsResult,
|
||||||
from astrbot.core.db.po import Conversation
|
AssistantMessageSegment,
|
||||||
|
ToolCallMessageSegment,
|
||||||
|
LLMResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
class ProviderType(enum.Enum):
|
"ProviderRequest",
|
||||||
CHAT_COMPLETION = "chat_completion"
|
"ProviderType",
|
||||||
SPEECH_TO_TEXT = "speech_to_text"
|
"ProviderMetaData",
|
||||||
TEXT_TO_SPEECH = "text_to_speech"
|
"ToolCallsResult",
|
||||||
|
"AssistantMessageSegment",
|
||||||
|
"ToolCallMessageSegment",
|
||||||
@dataclass
|
"LLMResponse",
|
||||||
class ProviderMetaData:
|
]
|
||||||
type: str
|
|
||||||
"""提供商适配器名称,如 openai, ollama"""
|
|
||||||
desc: str = ""
|
|
||||||
"""提供商适配器描述."""
|
|
||||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
|
||||||
cls_type: Type = None
|
|
||||||
|
|
||||||
default_config_tmpl: dict = None
|
|
||||||
"""平台的默认配置模板"""
|
|
||||||
provider_display_name: str = None
|
|
||||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProviderRequest:
|
|
||||||
prompt: str
|
|
||||||
"""提示词"""
|
|
||||||
session_id: str = ""
|
|
||||||
"""会话 ID"""
|
|
||||||
image_urls: List[str] = None
|
|
||||||
"""图片 URL 列表"""
|
|
||||||
func_tool: FuncCall = None
|
|
||||||
"""工具"""
|
|
||||||
contexts: List = None
|
|
||||||
"""上下文。格式与 openai 的上下文格式一致:
|
|
||||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
|
||||||
"""
|
|
||||||
system_prompt: str = ""
|
|
||||||
"""系统提示词"""
|
|
||||||
conversation: Conversation = None
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt.strip()})"
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMResponse:
|
|
||||||
role: str
|
|
||||||
"""角色, assistant, tool, err"""
|
|
||||||
completion_text: str = ""
|
|
||||||
"""LLM 返回的文本"""
|
|
||||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
|
||||||
"""工具调用参数"""
|
|
||||||
tools_call_name: List[str] = field(default_factory=list)
|
|
||||||
"""工具调用名称"""
|
|
||||||
|
|
||||||
raw_completion: ChatCompletion = None
|
|
||||||
_new_record: Dict[str, any] = None
|
|
||||||
|
|||||||
284
astrbot/core/provider/entities.py
Normal file
284
astrbot/core/provider/entities.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
import enum
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
from astrbot import logger
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Type
|
||||||
|
from .func_tool_manager import FuncCall
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
)
|
||||||
|
from astrbot.core.db.po import Conversation
|
||||||
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
import astrbot.core.message.components as Comp
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(enum.Enum):
|
||||||
|
CHAT_COMPLETION = "chat_completion"
|
||||||
|
SPEECH_TO_TEXT = "speech_to_text"
|
||||||
|
TEXT_TO_SPEECH = "text_to_speech"
|
||||||
|
EMBEDDING = "embedding"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderMetaData:
|
||||||
|
type: str
|
||||||
|
"""提供商适配器名称,如 openai, ollama"""
|
||||||
|
desc: str = ""
|
||||||
|
"""提供商适配器描述."""
|
||||||
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||||
|
cls_type: Type = None
|
||||||
|
|
||||||
|
default_config_tmpl: dict = None
|
||||||
|
"""平台的默认配置模板"""
|
||||||
|
provider_display_name: str = None
|
||||||
|
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallMessageSegment:
|
||||||
|
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
|
tool_call_id: str
|
||||||
|
content: str
|
||||||
|
role: str = "tool"
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"tool_call_id": self.tool_call_id,
|
||||||
|
"content": self.content,
|
||||||
|
"role": self.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistantMessageSegment:
|
||||||
|
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
|
content: str = None
|
||||||
|
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||||
|
role: str = "assistant"
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
ret = {
|
||||||
|
"role": self.role,
|
||||||
|
}
|
||||||
|
if self.content:
|
||||||
|
ret["content"] = self.content
|
||||||
|
elif self.tool_calls:
|
||||||
|
ret["tool_calls"] = self.tool_calls
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCallsResult:
|
||||||
|
"""工具调用结果"""
|
||||||
|
|
||||||
|
tool_calls_info: AssistantMessageSegment
|
||||||
|
"""函数调用的信息"""
|
||||||
|
tool_calls_result: List[ToolCallMessageSegment]
|
||||||
|
"""函数调用的结果"""
|
||||||
|
|
||||||
|
def to_openai_messages(self) -> List[Dict]:
|
||||||
|
ret = [
|
||||||
|
self.tool_calls_info.to_dict(),
|
||||||
|
*[item.to_dict() for item in self.tool_calls_result],
|
||||||
|
]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProviderRequest:
|
||||||
|
prompt: str
|
||||||
|
"""提示词"""
|
||||||
|
session_id: str = ""
|
||||||
|
"""会话 ID"""
|
||||||
|
image_urls: List[str] = None
|
||||||
|
"""图片 URL 列表"""
|
||||||
|
func_tool: FuncCall = None
|
||||||
|
"""可用的函数工具"""
|
||||||
|
contexts: List = None
|
||||||
|
"""上下文。格式与 openai 的上下文格式一致:
|
||||||
|
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||||
|
"""
|
||||||
|
system_prompt: str = ""
|
||||||
|
"""系统提示词"""
|
||||||
|
conversation: Conversation = None
|
||||||
|
|
||||||
|
tool_calls_result: ToolCallsResult = None
|
||||||
|
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def _print_friendly_context(self):
|
||||||
|
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||||
|
if not self.contexts:
|
||||||
|
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
|
||||||
|
|
||||||
|
result_parts = []
|
||||||
|
|
||||||
|
for ctx in self.contexts:
|
||||||
|
role = ctx.get("role", "unknown")
|
||||||
|
content = ctx.get("content", "")
|
||||||
|
|
||||||
|
if isinstance(content, str):
|
||||||
|
result_parts.append(f"{role}: {content}")
|
||||||
|
elif isinstance(content, list):
|
||||||
|
msg_parts = []
|
||||||
|
image_count = 0
|
||||||
|
|
||||||
|
for item in content:
|
||||||
|
item_type = item.get("type", "")
|
||||||
|
|
||||||
|
if item_type == "text":
|
||||||
|
msg_parts.append(item.get("text", ""))
|
||||||
|
elif item_type == "image_url":
|
||||||
|
image_count += 1
|
||||||
|
|
||||||
|
if image_count > 0:
|
||||||
|
if msg_parts:
|
||||||
|
msg_parts.append(f"[+{image_count} images]")
|
||||||
|
else:
|
||||||
|
msg_parts.append(f"[{image_count} images]")
|
||||||
|
|
||||||
|
result_parts.append(f"{role}: {''.join(msg_parts)}")
|
||||||
|
|
||||||
|
return result_parts
|
||||||
|
|
||||||
|
async def assemble_context(self) -> Dict:
|
||||||
|
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||||
|
if self.image_urls:
|
||||||
|
user_content = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for image_url in self.image_urls:
|
||||||
|
if image_url.startswith("http"):
|
||||||
|
image_path = await download_image_by_url(image_url)
|
||||||
|
image_data = await self._encode_image_bs64(image_path)
|
||||||
|
elif image_url.startswith("file:///"):
|
||||||
|
image_path = image_url.replace("file:///", "")
|
||||||
|
image_data = await self._encode_image_bs64(image_path)
|
||||||
|
else:
|
||||||
|
image_data = await self._encode_image_bs64(image_url)
|
||||||
|
if not image_data:
|
||||||
|
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||||
|
continue
|
||||||
|
user_content["content"].append(
|
||||||
|
{"type": "image_url", "image_url": {"url": image_data}}
|
||||||
|
)
|
||||||
|
return user_content
|
||||||
|
else:
|
||||||
|
return {"role": "user", "content": self.prompt}
|
||||||
|
|
||||||
|
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||||
|
"""将图片转换为 base64"""
|
||||||
|
if image_url.startswith("base64://"):
|
||||||
|
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||||
|
with open(image_url, "rb") as f:
|
||||||
|
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||||
|
return "data:image/jpeg;base64," + image_bs64
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
role: str
|
||||||
|
"""角色, assistant, tool, err"""
|
||||||
|
result_chain: MessageChain = None
|
||||||
|
"""返回的消息链"""
|
||||||
|
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||||
|
"""工具调用参数"""
|
||||||
|
tools_call_name: List[str] = field(default_factory=list)
|
||||||
|
"""工具调用名称"""
|
||||||
|
tools_call_ids: List[str] = field(default_factory=list)
|
||||||
|
"""工具调用 ID"""
|
||||||
|
|
||||||
|
raw_completion: ChatCompletion = None
|
||||||
|
_new_record: Dict[str, any] = None
|
||||||
|
|
||||||
|
_completion_text: str = ""
|
||||||
|
|
||||||
|
is_chunk: bool = False
|
||||||
|
"""是否是流式输出的单个 Chunk"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
role: str,
|
||||||
|
completion_text: str = "",
|
||||||
|
result_chain: MessageChain = None,
|
||||||
|
tools_call_args: List[Dict[str, any]] = None,
|
||||||
|
tools_call_name: List[str] = None,
|
||||||
|
tools_call_ids: List[str] = None,
|
||||||
|
raw_completion: ChatCompletion = None,
|
||||||
|
_new_record: Dict[str, any] = None,
|
||||||
|
is_chunk: bool = False,
|
||||||
|
):
|
||||||
|
"""初始化 LLMResponse
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role (str): 角色, assistant, tool, err
|
||||||
|
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
|
||||||
|
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
|
||||||
|
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
|
||||||
|
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
|
||||||
|
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||||
|
"""
|
||||||
|
if tools_call_args is None:
|
||||||
|
tools_call_args = []
|
||||||
|
if tools_call_name is None:
|
||||||
|
tools_call_name = []
|
||||||
|
if tools_call_ids is None:
|
||||||
|
tools_call_ids = []
|
||||||
|
|
||||||
|
self.role = role
|
||||||
|
self.completion_text = completion_text
|
||||||
|
self.result_chain = result_chain
|
||||||
|
self.tools_call_args = tools_call_args
|
||||||
|
self.tools_call_name = tools_call_name
|
||||||
|
self.tools_call_ids = tools_call_ids
|
||||||
|
self.raw_completion = raw_completion
|
||||||
|
self._new_record = _new_record
|
||||||
|
self.is_chunk = is_chunk
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completion_text(self):
|
||||||
|
if self.result_chain:
|
||||||
|
return self.result_chain.get_plain_text()
|
||||||
|
return self._completion_text
|
||||||
|
|
||||||
|
@completion_text.setter
|
||||||
|
def completion_text(self, value):
|
||||||
|
if self.result_chain:
|
||||||
|
self.result_chain.chain = [
|
||||||
|
comp
|
||||||
|
for comp in self.result_chain.chain
|
||||||
|
if not isinstance(comp, Comp.Plain)
|
||||||
|
] # 清空 Plain 组件
|
||||||
|
self.result_chain.chain.insert(0, Comp.Plain(value))
|
||||||
|
else:
|
||||||
|
self._completion_text = value
|
||||||
|
|
||||||
|
def to_openai_tool_calls(self) -> List[Dict]:
|
||||||
|
"""将工具调用信息转换为 OpenAI 格式"""
|
||||||
|
ret = []
|
||||||
|
for idx, tool_call_arg in enumerate(self.tools_call_args):
|
||||||
|
ret.append(
|
||||||
|
{
|
||||||
|
"id": self.tools_call_ids[idx],
|
||||||
|
"function": {
|
||||||
|
"name": self.tools_call_name[idx],
|
||||||
|
"arguments": json.dumps(tool_call_arg),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ret
|
||||||
@@ -1,7 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Dict, List, Awaitable
|
import os
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from typing import Dict, List, Awaitable, Literal, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.utils.log_pipe import LogPipe
|
||||||
|
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mcp
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
logger.warning(
|
||||||
|
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||||
|
|
||||||
|
SUPPORTED_TYPES = [
|
||||||
|
"string",
|
||||||
|
"number",
|
||||||
|
"object",
|
||||||
|
"array",
|
||||||
|
"boolean",
|
||||||
|
] # json schema 支持的数据类型
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -13,28 +48,162 @@ class FuncTool:
|
|||||||
name: str
|
name: str
|
||||||
parameters: Dict
|
parameters: Dict
|
||||||
description: str
|
description: str
|
||||||
handler: Awaitable
|
handler: Awaitable = None
|
||||||
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||||
|
handler_module_path: str = None
|
||||||
|
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||||
|
|
||||||
|
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||||
|
"""
|
||||||
active: bool = True
|
active: bool = True
|
||||||
"""是否激活"""
|
"""是否激活"""
|
||||||
|
|
||||||
|
origin: Literal["local", "mcp"] = "local"
|
||||||
|
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||||
|
|
||||||
|
# MCP 相关字段
|
||||||
|
mcp_server_name: str = None
|
||||||
|
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||||
|
mcp_client: MCPClient = None
|
||||||
|
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}), active={self.active})"
|
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||||
|
|
||||||
|
async def execute(self, **args) -> Any:
|
||||||
|
"""执行函数调用"""
|
||||||
|
if self.origin == "local":
|
||||||
|
if not self.handler:
|
||||||
|
raise Exception(f"Local function {self.name} has no handler")
|
||||||
|
return await self.handler(**args)
|
||||||
|
elif self.origin == "mcp":
|
||||||
|
if not self.mcp_client or not self.mcp_client.session:
|
||||||
|
raise Exception(f"MCP client for {self.name} is not available")
|
||||||
|
# 使用name属性而不是额外的mcp_tool_name
|
||||||
|
if ":" in self.name:
|
||||||
|
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
|
||||||
|
actual_tool_name = self.name.split(":")[-1]
|
||||||
|
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||||
|
else:
|
||||||
|
return await self.mcp_client.session.call_tool(self.name, args)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown function origin: {self.origin}")
|
||||||
|
|
||||||
|
|
||||||
SUPPORTED_TYPES = [
|
class MCPClient:
|
||||||
"string",
|
def __init__(self):
|
||||||
"number",
|
# Initialize session and client objects
|
||||||
"object",
|
self.session: Optional[mcp.ClientSession] = None
|
||||||
"array",
|
self.exit_stack = AsyncExitStack()
|
||||||
"boolean",
|
|
||||||
] # json schema 支持的数据类型
|
self.name = None
|
||||||
|
self.active: bool = True
|
||||||
|
self.tools: List[mcp.Tool] = []
|
||||||
|
self.server_errlogs: List[str] = []
|
||||||
|
|
||||||
|
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||||
|
"""连接到 MCP 服务器
|
||||||
|
|
||||||
|
如果 `url` 参数存在:
|
||||||
|
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||||
|
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||||
|
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||||
|
"""
|
||||||
|
cfg = mcp_server_config.copy()
|
||||||
|
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||||
|
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||||
|
cfg = cfg["mcpServers"][key_0]
|
||||||
|
cfg.pop("active", None) # Remove active flag from config
|
||||||
|
|
||||||
|
if "url" in cfg:
|
||||||
|
is_sse = True
|
||||||
|
if cfg.get("transport") == "streamable_http":
|
||||||
|
is_sse = False
|
||||||
|
if is_sse:
|
||||||
|
# SSE transport method
|
||||||
|
self._streams_context = sse_client(
|
||||||
|
url=cfg["url"],
|
||||||
|
headers=cfg.get("headers", {}),
|
||||||
|
timeout=cfg.get("timeout", 5),
|
||||||
|
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||||
|
)
|
||||||
|
streams = await self._streams_context.__aenter__()
|
||||||
|
|
||||||
|
# Create a new client session
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.ClientSession(*streams)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||||
|
sse_read_timeout = timedelta(
|
||||||
|
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||||
|
)
|
||||||
|
self._streams_context = streamablehttp_client(
|
||||||
|
url=cfg["url"],
|
||||||
|
headers=cfg.get("headers", {}),
|
||||||
|
timeout=timeout,
|
||||||
|
sse_read_timeout=sse_read_timeout,
|
||||||
|
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||||
|
)
|
||||||
|
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||||
|
|
||||||
|
# Create a new client session
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
server_params = mcp.StdioServerParameters(
|
||||||
|
**cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def callback(msg: str):
|
||||||
|
# 处理 MCP 服务的错误日志
|
||||||
|
self.server_errlogs.append(msg)
|
||||||
|
|
||||||
|
stdio_transport = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.stdio_client(
|
||||||
|
server_params,
|
||||||
|
errlog=LogPipe(
|
||||||
|
level=logging.ERROR,
|
||||||
|
logger=logger,
|
||||||
|
identifier=f"MCPServer-{name}",
|
||||||
|
callback=callback,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new client session
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
mcp.ClientSession(*stdio_transport)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.session.initialize()
|
||||||
|
|
||||||
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
|
"""List all tools from the server and save them to self.tools"""
|
||||||
|
response = await self.session.list_tools()
|
||||||
|
logger.debug(f"MCP server {self.name} list tools response: {response}")
|
||||||
|
self.tools = response.tools
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""Clean up resources"""
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
|
||||||
|
|
||||||
class FuncCall:
|
class FuncCall:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.func_list: List[FuncTool] = []
|
self.func_list: List[FuncTool] = []
|
||||||
|
"""内部加载的 func tools"""
|
||||||
|
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||||
|
"""MCP 服务列表"""
|
||||||
|
self.mcp_service_queue = asyncio.Queue()
|
||||||
|
"""用于外部控制 MCP 服务的启停"""
|
||||||
|
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
return len(self.func_list) == 0
|
return len(self.func_list) == 0
|
||||||
@@ -46,14 +215,16 @@ class FuncCall:
|
|||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Awaitable,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""添加函数调用工具
|
||||||
为函数调用(function-calling / tools-use)添加工具。
|
|
||||||
|
|
||||||
@param name: 函数名
|
@param name: 函数名
|
||||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||||
@param desc: 函数描述
|
@param desc: 函数描述
|
||||||
@param func_obj: 处理函数
|
@param func_obj: 处理函数
|
||||||
"""
|
"""
|
||||||
|
# check if the tool has been added before
|
||||||
|
self.remove_func(name)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
"properties": {},
|
"properties": {},
|
||||||
@@ -70,13 +241,14 @@ class FuncCall:
|
|||||||
handler=handler,
|
handler=handler,
|
||||||
)
|
)
|
||||||
self.func_list.append(_func)
|
self.func_list.append(_func)
|
||||||
|
logger.info(f"添加函数调用工具: {name}")
|
||||||
|
|
||||||
def remove_func(self, name: str) -> None:
|
def remove_func(self, name: str) -> None:
|
||||||
"""
|
"""
|
||||||
删除一个函数调用工具。
|
删除一个函数调用工具。
|
||||||
"""
|
"""
|
||||||
for i, f in enumerate(self.func_list):
|
for i, f in enumerate(self.func_list):
|
||||||
if f["name"] == name:
|
if f.name == name:
|
||||||
self.func_list.pop(i)
|
self.func_list.pop(i)
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -86,24 +258,195 @@ class FuncCall:
|
|||||||
return f
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_func_desc_openai_style(self) -> list:
|
async def _init_mcp_clients(self) -> None:
|
||||||
|
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"weather": {
|
||||||
|
"command": "uv",
|
||||||
|
"args": [
|
||||||
|
"--directory",
|
||||||
|
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
|
||||||
|
"run",
|
||||||
|
"weather.py"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
data_dir = get_astrbot_data_path()
|
||||||
|
|
||||||
|
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||||
|
if not os.path.exists(mcp_json_file):
|
||||||
|
# 配置文件不存在错误处理
|
||||||
|
with open(mcp_json_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||||
|
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
mcp_server_json_obj: Dict[str, Dict] = json.load(
|
||||||
|
open(mcp_json_file, "r", encoding="utf-8")
|
||||||
|
)["mcpServers"]
|
||||||
|
|
||||||
|
for name in mcp_server_json_obj.keys():
|
||||||
|
cfg = mcp_server_json_obj[name]
|
||||||
|
if cfg.get("active", True):
|
||||||
|
event = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(name, cfg, event)
|
||||||
|
)
|
||||||
|
self.mcp_client_event[name] = event
|
||||||
|
|
||||||
|
async def mcp_service_selector(self):
|
||||||
|
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
|
||||||
|
|
||||||
|
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
|
||||||
|
|
||||||
|
{"type": "init"} 初始化所有MCP客户端
|
||||||
|
|
||||||
|
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
|
||||||
|
|
||||||
|
{"type": "terminate"} 终止所有MCP客户端
|
||||||
|
|
||||||
|
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
data = await self.mcp_service_queue.get()
|
||||||
|
if data["type"] == "init":
|
||||||
|
if "name" in data:
|
||||||
|
event = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
self._init_mcp_client_task_wrapper(
|
||||||
|
data["name"], data["cfg"], event
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.mcp_client_event[data["name"]] = event
|
||||||
|
else:
|
||||||
|
await self._init_mcp_clients()
|
||||||
|
elif data["type"] == "terminate":
|
||||||
|
if "name" in data:
|
||||||
|
# await self._terminate_mcp_client(data["name"])
|
||||||
|
if data["name"] in self.mcp_client_event:
|
||||||
|
self.mcp_client_event[data["name"]].set()
|
||||||
|
self.mcp_client_event.pop(data["name"], None)
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (
|
||||||
|
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
for name in self.mcp_client_dict.keys():
|
||||||
|
# await self._terminate_mcp_client(name)
|
||||||
|
# self.mcp_client_event[name].set()
|
||||||
|
if name in self.mcp_client_event:
|
||||||
|
self.mcp_client_event[name].set()
|
||||||
|
self.mcp_client_event.pop(name, None)
|
||||||
|
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||||
|
|
||||||
|
async def _init_mcp_client_task_wrapper(
|
||||||
|
self, name: str, cfg: dict, event: asyncio.Event
|
||||||
|
) -> None:
|
||||||
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
|
try:
|
||||||
|
await self._init_mcp_client(name, cfg)
|
||||||
|
await event.wait()
|
||||||
|
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||||
|
|
||||||
|
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||||
|
"""初始化单个MCP客户端"""
|
||||||
|
try:
|
||||||
|
# 先清理之前的客户端,如果存在
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
|
||||||
|
mcp_client = MCPClient()
|
||||||
|
mcp_client.name = name
|
||||||
|
self.mcp_client_dict[name] = mcp_client
|
||||||
|
await mcp_client.connect_to_server(config, name)
|
||||||
|
tools_res = await mcp_client.list_tools_and_save()
|
||||||
|
tool_names = [tool.name for tool in tools_res.tools]
|
||||||
|
|
||||||
|
# 移除该MCP服务之前的工具(如有)
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
||||||
|
for tool in mcp_client.tools:
|
||||||
|
func_tool = FuncTool(
|
||||||
|
name=tool.name,
|
||||||
|
parameters=tool.inputSchema,
|
||||||
|
description=tool.description,
|
||||||
|
origin="mcp",
|
||||||
|
mcp_server_name=name,
|
||||||
|
mcp_client=mcp_client,
|
||||||
|
)
|
||||||
|
self.func_list.append(func_tool)
|
||||||
|
|
||||||
|
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||||
|
# 发生错误时确保客户端被清理
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _terminate_mcp_client(self, name: str) -> None:
|
||||||
|
"""关闭并清理MCP客户端"""
|
||||||
|
if name in self.mcp_client_dict:
|
||||||
|
try:
|
||||||
|
# 关闭MCP连接
|
||||||
|
await self.mcp_client_dict[name].cleanup()
|
||||||
|
del self.mcp_client_dict[name]
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||||
|
# 移除关联的FuncTool
|
||||||
|
self.func_list = [
|
||||||
|
f
|
||||||
|
for f in self.func_list
|
||||||
|
if not (f.origin == "mcp" and f.mcp_server_name == name)
|
||||||
|
]
|
||||||
|
logger.info(f"已关闭 MCP 服务 {name}")
|
||||||
|
|
||||||
|
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||||
"""
|
"""
|
||||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||||
"""
|
"""
|
||||||
_l = []
|
_l = []
|
||||||
|
# 处理所有工具(包括本地和MCP工具)
|
||||||
for f in self.func_list:
|
for f in self.func_list:
|
||||||
if not f.active:
|
if not f.active:
|
||||||
continue
|
continue
|
||||||
_l.append(
|
func_ = {
|
||||||
{
|
"type": "function",
|
||||||
"type": "function",
|
"function": {
|
||||||
"function": {
|
"name": f.name,
|
||||||
"name": f.name,
|
# "parameters": f.parameters,
|
||||||
"parameters": f.parameters,
|
"description": f.description,
|
||||||
"description": f.description,
|
},
|
||||||
},
|
}
|
||||||
}
|
func_["function"]["parameters"] = f.parameters
|
||||||
)
|
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||||
|
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||||
|
del func_["function"]["parameters"]
|
||||||
|
_l.append(func_)
|
||||||
return _l
|
return _l
|
||||||
|
|
||||||
def get_func_desc_anthropic_style(self) -> list:
|
def get_func_desc_anthropic_style(self) -> list:
|
||||||
@@ -129,22 +472,86 @@ class FuncCall:
|
|||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def get_func_desc_google_genai_style(self) -> Dict:
|
def get_func_desc_google_genai_style(self) -> dict:
|
||||||
|
"""
|
||||||
|
获得 Google GenAI API 风格的**已经激活**的工具描述
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Gemini API 支持的数据类型和格式
|
||||||
|
supported_types = {
|
||||||
|
"string",
|
||||||
|
"number",
|
||||||
|
"integer",
|
||||||
|
"boolean",
|
||||||
|
"array",
|
||||||
|
"object",
|
||||||
|
"null",
|
||||||
|
}
|
||||||
|
supported_formats = {
|
||||||
|
"string": {"enum", "date-time"},
|
||||||
|
"integer": {"int32", "int64"},
|
||||||
|
"number": {"float", "double"},
|
||||||
|
}
|
||||||
|
|
||||||
|
def convert_schema(schema: dict) -> dict:
|
||||||
|
"""转换 schema 为 Gemini API 格式"""
|
||||||
|
|
||||||
|
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||||
|
if "anyOf" in schema:
|
||||||
|
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if "type" in schema and schema["type"] in supported_types:
|
||||||
|
result["type"] = schema["type"]
|
||||||
|
if "format" in schema and schema["format"] in supported_formats.get(
|
||||||
|
result["type"], set()
|
||||||
|
):
|
||||||
|
result["format"] = schema["format"]
|
||||||
|
else:
|
||||||
|
# 暂时指定默认为null
|
||||||
|
result["type"] = "null"
|
||||||
|
|
||||||
|
support_fields = {
|
||||||
|
"title",
|
||||||
|
"description",
|
||||||
|
"enum",
|
||||||
|
"minimum",
|
||||||
|
"maximum",
|
||||||
|
"maxItems",
|
||||||
|
"minItems",
|
||||||
|
"nullable",
|
||||||
|
"required",
|
||||||
|
}
|
||||||
|
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||||
|
|
||||||
|
if "properties" in schema:
|
||||||
|
properties = {}
|
||||||
|
for key, value in schema["properties"].items():
|
||||||
|
prop_value = convert_schema(value)
|
||||||
|
if "default" in prop_value:
|
||||||
|
del prop_value["default"]
|
||||||
|
properties[key] = prop_value
|
||||||
|
|
||||||
|
if properties: # 只在有非空属性时添加
|
||||||
|
result["properties"] = properties
|
||||||
|
|
||||||
|
if "items" in schema:
|
||||||
|
result["items"] = convert_schema(schema["items"])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"name": f.name,
|
||||||
|
"description": f.description,
|
||||||
|
**({"parameters": convert_schema(f.parameters)}),
|
||||||
|
}
|
||||||
|
for f in self.func_list
|
||||||
|
if f.active
|
||||||
|
]
|
||||||
|
|
||||||
declarations = {}
|
declarations = {}
|
||||||
tools = []
|
|
||||||
for f in self.func_list:
|
|
||||||
if not f.active:
|
|
||||||
continue
|
|
||||||
|
|
||||||
func_declaration = {"name": f.name, "description": f.description}
|
|
||||||
|
|
||||||
# 检查并添加非空的properties参数
|
|
||||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
|
||||||
if params.get("properties", {}):
|
|
||||||
func_declaration["parameters"] = params
|
|
||||||
|
|
||||||
tools.append(func_declaration)
|
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
declarations["function_declarations"] = tools
|
declarations["function_declarations"] = tools
|
||||||
return declarations
|
return declarations
|
||||||
@@ -156,9 +563,9 @@ class FuncCall:
|
|||||||
continue
|
continue
|
||||||
_l.append(
|
_l.append(
|
||||||
{
|
{
|
||||||
"name": f["name"],
|
"name": f.name,
|
||||||
"parameters": f["parameters"],
|
"parameters": f.parameters,
|
||||||
"description": f["description"],
|
"description": f.description,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||||
@@ -208,14 +615,11 @@ class FuncCall:
|
|||||||
func_name = tool["name"]
|
func_name = tool["name"]
|
||||||
args = tool["args"]
|
args = tool["args"]
|
||||||
# 调用函数
|
# 调用函数
|
||||||
tool_callable = None
|
func_tool = self.get_func(func_name)
|
||||||
for func in self.func_list:
|
if not func_tool:
|
||||||
if func.name == func_name:
|
|
||||||
tool_callable = func.star_handler_metadata.handler
|
|
||||||
break
|
|
||||||
if not tool_callable:
|
|
||||||
raise Exception(f"Request function {func_name} not found.")
|
raise Exception(f"Request function {func_name} not found.")
|
||||||
ret = await tool_callable(**args)
|
|
||||||
|
ret = await func_tool.execute(**args)
|
||||||
if ret:
|
if ret:
|
||||||
tool_call_result.append(str(ret))
|
tool_call_result.append(str(ret))
|
||||||
return tool_call_result, True
|
return tool_call_result, True
|
||||||
@@ -225,3 +629,8 @@ class FuncCall:
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.func_list)
|
return str(self.func_list)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
for name in self.mcp_client_dict.keys():
|
||||||
|
await self._terminate_mcp_client(name)
|
||||||
|
logger.debug(f"清理 MCP 客户端 {name} 资源")
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user